diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..5008ddf Binary files /dev/null and b/.DS_Store differ diff --git a/README.md b/README.md index 92e0dc4..a6b46f9 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,135 @@ Abstract: Recent advancements in world models have revolutionized dynamic enviro ## TODO List - **[2025-05-30]** ✅ Release [paper](https://arxiv.org/abs/2505.22421) -- [ ] Release inference code -- [ ] Release model checkpoints +- **[2026-02-20]** ✅ Release inference code +- **[2026-02-20]** ✅ Release model checkpoints + +## Getting Started + +Environment Requirement + +We recommend first use `conda` to create virtual environment, and install needed libraries. + +```bash +conda create -n geodrive python=3.10 -y +conda activate geodrive +pip install -r requirements.txt +``` + +Then, install custom diffusers with: + +```bash +cd ./diffusers +pip install -e . +``` + +Next, install required ffmpeg: + +```bash +conda install -c conda-forge ffmpeg -y +``` + +Data Preparation + + +We use processed data from NuScenes. Please follow instruction from another folder 'monst3r' to prepare the data. + + +## 🏃🏼 Running Scripts + +Training + +Train the GeoDrive using the script: + +```bash +export MODEL_PATH="THUDM/CogVideoX-5b-I2V" +export CACHE_PATH="~/.cache" +export CONDITION_DATASET_PATH="PATH/TO/CONDITION_TRAIN" +export VIDEO_DATASET_PATH="PATH/TO/NUSCENES" +export VAL_CONDITION_DATASET_PATH=""PATH/TO/CONDITION_VAL" +export PROJECT_NAME="GeoDrive" +export RUNS_NAME="GeoDrive" +export OUTPUT_PATH="./${PROJECT_NAME}/${RUNS_NAME}" +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false +export ACCELERATE_LAUNCH_WAIT_TIMEOUT=300 +export DS_SKIP_CUDA_CHECK=1 + +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7} + +torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=10086 --master_port=$MASTER_PORT \ + train.py \ + --pretrained_model_name_or_path $MODEL_PATH \ + --cache_dir $CACHE_PATH \ + --meta_file_path PATH/TO/META_TRAIN \ + --val_meta_file_path PATH/TO/META_VAL \ + --video_data_root $VIDEO_DATASET_PATH \ + --condition_data_root $CONDITION_DATASET_PATH \ + --val_condition_data_root $VAL_CONDITION_DATASET_PATH \ + --dataloader_num_workers 1 \ + --num_validation_videos 1 \ + --validation_epochs 2 \ + --seed 42 \ + --mixed_precision bf16 \ + --output_dir $OUTPUT_PATH \ + --height 480 \ + --width 720 \ + --fps 10 \ + --video_reshape_mode "center" \ + --branch_layer_num 2 \ + --train_batch_size 1 \ + --num_train_epochs 20 \ + --checkpointing_steps 2000 \ + --validating_steps 1000 \ + --gradient_accumulation_steps 1 \ + --learning_rate 1e-5 \ + --lr_scheduler cosine_with_restarts \ + --lr_warmup_steps 1000 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --noised_image_dropout 0.05 \ + --gradient_checkpointing \ + --optimizer AdamW \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --tracker_name $PROJECT_NAME \ + --runs_name $RUNS_NAME \ + --mix_train_ratio 0 \ + --first_frame_gt + +``` + +Inference + +You can run inference with the script: + +```bash +cd infer + +pretrained_model_name_or_path="THUDM/CogVideoX-5b-I2V" +checkpoint_path="PATH/TO/CKPT" +meta_file_path="PATH/TO/META" +condition_data_root="PATH/TO/CONDITION" +video_data_root="PATH/TO/NUSCENES" + +python run_validation.py \ +--checkpoint_path $checkpoint_path \ +--pretrained_model_name_or_path $pretrained_model_name_or_path \ +--meta_file_path $meta_file_path \ +--condition_data_root $condition_data_root \ +--video_data_root $video_data_root \ +--output_dir /PATH/TO/OUTPUT \ +--height 480 \ +--width 720 \ +--max_num_frames 49 \ +--mixed_precision bf16 \ +--target_frames 25 \ + +``` ## Citation diff --git a/diffusers/.gitignore b/diffusers/.gitignore new file mode 100755 index 0000000..15617d5 --- /dev/null +++ b/diffusers/.gitignore @@ -0,0 +1,178 @@ +# Initially taken from GitHub's Python gitignore file + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# tests and logs +tests/fixtures/cached_*_text.txt +logs/ +lightning_logs/ +lang_code_data/ + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a Python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# vscode +.vs +.vscode + +# Pycharm +.idea + +# TF code +tensorflow_code + +# Models +proc_data + +# examples +runs +/runs_old +/wandb +/examples/runs +/examples/**/*.args +/examples/rag/sweep + +# data +/data +serialization_dir + +# emacs +*.*~ +debug.env + +# vim +.*.swp + +# ctags +tags + +# pre-commit +.pre-commit* + +# .lock +*.lock + +# DS_Store (MacOS) +.DS_Store + +# RL pipelines may produce mp4 outputs +*.mp4 + +# dependencies +/transformers + +# ruff +.ruff_cache + +# wandb +wandb \ No newline at end of file diff --git a/diffusers/CITATION.cff b/diffusers/CITATION.cff new file mode 100755 index 0000000..09fc6c7 --- /dev/null +++ b/diffusers/CITATION.cff @@ -0,0 +1,52 @@ +cff-version: 1.2.0 +title: 'Diffusers: State-of-the-art diffusion models' +message: >- + If you use this software, please cite it using the + metadata from this file. +type: software +authors: + - given-names: Patrick + family-names: von Platen + - given-names: Suraj + family-names: Patil + - given-names: Anton + family-names: Lozhkov + - given-names: Pedro + family-names: Cuenca + - given-names: Nathan + family-names: Lambert + - given-names: Kashif + family-names: Rasul + - given-names: Mishig + family-names: Davaadorj + - given-names: Dhruv + family-names: Nair + - given-names: Sayak + family-names: Paul + - given-names: Steven + family-names: Liu + - given-names: William + family-names: Berman + - given-names: Yiyi + family-names: Xu + - given-names: Thomas + family-names: Wolf +repository-code: 'https://github.com/huggingface/diffusers' +abstract: >- + Diffusers provides pretrained diffusion models across + multiple modalities, such as vision and audio, and serves + as a modular toolbox for inference and training of + diffusion models. +keywords: + - deep-learning + - pytorch + - image-generation + - hacktoberfest + - diffusion + - text2image + - image2image + - score-based-generative-modeling + - stable-diffusion + - stable-diffusion-diffusers +license: Apache-2.0 +version: 0.12.1 diff --git a/diffusers/CODE_OF_CONDUCT.md b/diffusers/CODE_OF_CONDUCT.md new file mode 100755 index 0000000..2139079 --- /dev/null +++ b/diffusers/CODE_OF_CONDUCT.md @@ -0,0 +1,130 @@ + +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, caste, color, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall Diffusers community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Spamming issues or PRs with links to projects unrelated to this library +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +feedback@huggingface.co. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +https://www.contributor-covenant.org/version/2/1/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/diffusers/CONTRIBUTING.md b/diffusers/CONTRIBUTING.md new file mode 100755 index 0000000..049d317 --- /dev/null +++ b/diffusers/CONTRIBUTING.md @@ -0,0 +1,506 @@ + + +# How to contribute to Diffusers 🧨 + +We ❤️ contributions from the open-source community! Everyone is welcome, and all types of participation –not just code– are valued and appreciated. Answering questions, helping others, reaching out, and improving the documentation are all immensely valuable to the community, so don't be afraid and get involved if you're up for it! + +Everyone is encouraged to start by saying 👋 in our public Discord channel. We discuss the latest trends in diffusion models, ask questions, show off personal projects, help each other with contributions, or just hang out ☕. Join us on Discord + +Whichever way you choose to contribute, we strive to be part of an open, welcoming, and kind community. Please, read our [code of conduct](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md) and be mindful to respect it during your interactions. We also recommend you become familiar with the [ethical guidelines](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines) that guide our project and ask you to adhere to the same principles of transparency and responsibility. + +We enormously value feedback from the community, so please do not be afraid to speak up if you believe you have valuable feedback that can help improve the library - every message, comment, issue, and pull request (PR) is read and considered. + +## Overview + +You can contribute in many ways ranging from answering questions on issues to adding new diffusion models to +the core library. + +In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community. + +* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR). +* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose). +* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues). +* 4. Fix a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). +* 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source). +* 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples). +* 7. Contribute to the [examples](https://github.com/huggingface/diffusers/tree/main/examples). +* 8. Fix a more difficult issue, marked by the "Good second issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22). +* 9. Add a new pipeline, model, or scheduler, see ["New Pipeline/Model"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) and ["New scheduler"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) issues. For this contribution, please have a look at [Design Philosophy](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md). + +As said before, **all contributions are valuable to the community**. +In the following, we will explain each contribution a bit more in detail. + +For all contributions 4-9, you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr). + +### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord + +Any question or comment related to the Diffusers library can be asked on the [discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/) or on [Discord](https://discord.gg/G7tWnz98XR). Such questions and comments include (but are not limited to): +- Reports of training or inference experiments in an attempt to share knowledge +- Presentation of personal projects +- Questions to non-official training examples +- Project proposals +- General feedback +- Paper summaries +- Asking for help on personal projects that build on top of the Diffusers library +- General questions +- Ethical questions regarding diffusion models +- ... + +Every question that is asked on the forum or on Discord actively encourages the community to publicly +share knowledge and might very well help a beginner in the future who has the same question you're +having. Please do pose any questions you might have. +In the same spirit, you are of immense help to the community by answering such questions because this way you are publicly documenting knowledge for everybody to learn from. + +**Please** keep in mind that the more effort you put into asking or answering a question, the higher +the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database. +In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section. + +**NOTE about channels**: +[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago. +In addition, questions and answers posted in the forum can easily be linked to. +In contrast, *Discord* has a chat-like format that invites fast back-and-forth communication. +While it will most likely take less time for you to get an answer to your question on Discord, your +question won't be visible anymore over time. Also, it's much harder to find information that was posted a while back on Discord. We therefore strongly recommend using the forum for high-quality questions and answers in an attempt to create long-lasting knowledge for the community. If discussions on Discord lead to very interesting answers and conclusions, we recommend posting the results on the forum to make the information more available for future readers. + +### 2. Opening new issues on the GitHub issues tab + +The 🧨 Diffusers library is robust and reliable thanks to the users who notify us of +the problems they encounter. So thank you for reporting an issue. + +Remember, GitHub issues are reserved for technical questions directly related to the Diffusers library, bug reports, feature requests, or feedback on the library design. + +In a nutshell, this means that everything that is **not** related to the **code of the Diffusers library** (including the documentation) should **not** be asked on GitHub, but rather on either the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR). + +**Please consider the following guidelines when opening a new issue**: +- Make sure you have searched whether your issue has already been asked before (use the search bar on GitHub under Issues). +- Please never report a new issue on another (related) issue. If another issue is highly related, please +open a new issue nevertheless and link to the related issue. +- Make sure your issue is written in English. Please use one of the great, free online translation services, such as [DeepL](https://www.deepl.com/translator) to translate from your native language to English if you are not comfortable in English. +- Check whether your issue might be solved by updating to the newest Diffusers version. Before posting your issue, please make sure that `python -c "import diffusers; print(diffusers.__version__)"` is higher or matches the latest Diffusers version. +- Remember that the more effort you put into opening a new issue, the higher the quality of your answer will be and the better the overall quality of the Diffusers issues. + +New issues usually include the following. + +#### 2.1. Reproducible, minimal bug reports + +A bug report should always have a reproducible code snippet and be as minimal and concise as possible. +This means in more detail: +- Narrow the bug down as much as you can, **do not just dump your whole code file**. +- Format your code. +- Do not include any external libraries except for Diffusers depending on them. +- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue. +- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, she cannot solve it. +- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell. +- If in order to reproduce your issue a model and/or dataset is required, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible. + +For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section. + +You can open a bug report [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml). + +#### 2.2. Feature requests + +A world-class feature request addresses the following points: + +1. Motivation first: +* Is it related to a problem/frustration with the library? If so, please explain +why. Providing a code snippet that demonstrates the problem is best. +* Is it related to something you would need for a project? We'd love to hear +about it! +* Is it something you worked on and think could benefit the community? +Awesome! Tell us what problem it solved for you. +2. Write a *full paragraph* describing the feature; +3. Provide a **code snippet** that demonstrates its future use; +4. In case this is related to a paper, please attach a link; +5. Attach any additional information (drawings, screenshots, etc.) you think may help. + +You can open a feature request [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=). + +#### 2.3 Feedback + +Feedback about the library design and why it is good or not good helps the core maintainers immensely to build a user-friendly library. To understand the philosophy behind the current design philosophy, please have a look [here](https://huggingface.co/docs/diffusers/conceptual/philosophy). If you feel like a certain design choice does not fit with the current design philosophy, please explain why and how it should be changed. If a certain design choice follows the design philosophy too much, hence restricting use cases, explain why and how it should be changed. +If a certain design choice is very useful for you, please also leave a note as this is great feedback for future design decisions. + +You can open an issue about feedback [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=). + +#### 2.4 Technical questions + +Technical questions are mainly about why certain code of the library was written in a certain way, or what a certain part of the code does. Please make sure to link to the code in question and please provide detail on +why this part of the code is difficult to understand. + +You can open an issue about a technical question [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml). + +#### 2.5 Proposal to add a new model, scheduler, or pipeline + +If the diffusion model community released a new model, pipeline, or scheduler that you would like to see in the Diffusers library, please provide the following information: + +* Short description of the diffusion pipeline, model, or scheduler and link to the paper or public release. +* Link to any of its open-source implementation. +* Link to the model weights if they are available. + +If you are willing to contribute to the model yourself, let us know so we can best guide you. Also, don't forget +to tag the original author of the component (model, scheduler, pipeline, etc.) by GitHub handle if you can find it. + +You can open a request for a model/pipeline/scheduler [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml). + +### 3. Answering issues on the GitHub issues tab + +Answering issues on GitHub might require some technical knowledge of Diffusers, but we encourage everybody to give it a try even if you are not 100% certain that your answer is correct. +Some tips to give a high-quality answer to an issue: +- Be as concise and minimal as possible. +- Stay on topic. An answer to the issue should concern the issue and only the issue. +- Provide links to code, papers, or other sources that prove or encourage your point. +- Answer in code. If a simple code snippet is the answer to the issue or shows how the issue can be solved, please provide a fully reproducible code snippet. + +Also, many issues tend to be simply off-topic, duplicates of other issues, or irrelevant. It is of great +help to the maintainers if you can answer such issues, encouraging the author of the issue to be +more precise, provide the link to a duplicated issue or redirect them to [the forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR). + +If you have verified that the issued bug report is correct and requires a correction in the source code, +please have a look at the next sections. + +For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section. + +### 4. Fixing a "Good first issue" + +*Good first issues* are marked by the [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) label. Usually, the issue already +explains how a potential solution should look so that it is easier to fix. +If the issue hasn't been closed and you would like to try to fix this issue, you can just leave a message "I would like to try this issue.". There are usually three scenarios: +- a.) The issue description already proposes a fix. In this case and if the solution makes sense to you, you can open a PR or draft PR to fix it. +- b.) The issue description does not propose a fix. In this case, you can ask what a proposed fix could look like and someone from the Diffusers team should answer shortly. If you have a good idea of how to fix it, feel free to directly open a PR. +- c.) There is already an open PR to fix the issue, but the issue hasn't been closed yet. If the PR has gone stale, you can simply open a new PR and link to the stale PR. PRs often go stale if the original contributor who wanted to fix the issue suddenly cannot find the time anymore to proceed. This often happens in open-source and is very normal. In this case, the community will be very happy if you give it a new try and leverage the knowledge of the existing PR. If there is already a PR and it is active, you can help the author by giving suggestions, reviewing the PR or even asking whether you can contribute to the PR. + + +### 5. Contribute to the documentation + +A good library **always** has good documentation! The official documentation is often one of the first points of contact for new users of the library, and therefore contributing to the documentation is a **highly +valuable contribution**. + +Contributing to the library can have many forms: + +- Correcting spelling or grammatical errors. +- Correct incorrect formatting of the docstring. If you see that the official documentation is weirdly displayed or a link is broken, we are very happy if you take some time to correct it. +- Correct the shape or dimensions of a docstring input or output tensor. +- Clarify documentation that is hard to understand or incorrect. +- Update outdated code examples. +- Translating the documentation to another language. + +Anything displayed on [the official Diffusers doc page](https://huggingface.co/docs/diffusers/index) is part of the official documentation and can be corrected, adjusted in the respective [documentation source](https://github.com/huggingface/diffusers/tree/main/docs/source). + +Please have a look at [this page](https://github.com/huggingface/diffusers/tree/main/docs) on how to verify changes made to the documentation locally. + + +### 6. Contribute a community pipeline + +[Pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) are usually the first point of contact between the Diffusers library and the user. +Pipelines are examples of how to use Diffusers [models](https://huggingface.co/docs/diffusers/api/models/overview) and [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview). +We support two types of pipelines: + +- Official Pipelines +- Community Pipelines + +Both official and community pipelines follow the same design and consist of the same type of components. + +Official pipelines are tested and maintained by the core maintainers of Diffusers. Their code +resides in [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines). +In contrast, community pipelines are contributed and maintained purely by the **community** and are **not** tested. +They reside in [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and while they can be accessed via the [PyPI diffusers package](https://pypi.org/project/diffusers/), their code is not part of the PyPI distribution. + +The reason for the distinction is that the core maintainers of the Diffusers library cannot maintain and test all +possible ways diffusion models can be used for inference, but some of them may be of interest to the community. +Officially released diffusion pipelines, +such as Stable Diffusion are added to the core src/diffusers/pipelines package which ensures +high quality of maintenance, no backward-breaking code changes, and testing. +More bleeding edge pipelines should be added as community pipelines. If usage for a community pipeline is high, the pipeline can be moved to the official pipelines upon request from the community. This is one of the ways we strive to be a community-driven library. + +To add a community pipeline, one should add a .py file to [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and adapt the [examples/community/README.md](https://github.com/huggingface/diffusers/tree/main/examples/community/README.md) to include an example of the new pipeline. + +An example can be seen [here](https://github.com/huggingface/diffusers/pull/2400). + +Community pipeline PRs are only checked at a superficial level and ideally they should be maintained by their original authors. + +Contributing a community pipeline is a great way to understand how Diffusers models and schedulers work. Having contributed a community pipeline is usually the first stepping stone to contributing an official pipeline to the +core package. + +### 7. Contribute to training examples + +Diffusers examples are a collection of training scripts that reside in [examples](https://github.com/huggingface/diffusers/tree/main/examples). + +We support two types of training examples: + +- Official training examples +- Research training examples + +Research training examples are located in [examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) whereas official training examples include all folders under [examples](https://github.com/huggingface/diffusers/tree/main/examples) except the `research_projects` and `community` folders. +The official training examples are maintained by the Diffusers' core maintainers whereas the research training examples are maintained by the community. +This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models. +If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author. + +Both official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. In order for the user to make use of the +training examples, it is required to clone the repository: + +```bash +git clone https://github.com/huggingface/diffusers +``` + +as well as to install all additional dependencies required for training: + +```bash +cd diffusers +pip install -r examples//requirements.txt +``` + +Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt). + +Training examples of the Diffusers library should adhere to the following philosophy: +- All the code necessary to run the examples should be found in a single Python file. +- One should be able to run the example from the command line with `python .py --args`. +- Examples should be kept simple and serve as **an example** on how to use Diffusers for training. The purpose of example scripts is **not** to create state-of-the-art diffusion models, but rather to reproduce known training schemes without adding too much custom logic. As a byproduct of this point, our examples also strive to serve as good educational materials. + +To contribute an example, it is highly recommended to look at already existing examples such as [dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) to get an idea of how they should look like. +We strongly advise contributors to make use of the [Accelerate library](https://github.com/huggingface/accelerate) as it's tightly integrated +with Diffusers. +Once an example script works, please make sure to add a comprehensive `README.md` that states how to use the example exactly. This README should include: +- An example command on how to run the example script as shown [here e.g.](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch). +- A link to some training results (logs, models, ...) that show what the user can expect as shown [here e.g.](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5). +- If you are adding a non-official/research training example, **please don't forget** to add a sentence that you are maintaining this training example which includes your git handle as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations). + +If you are contributing to the official training examples, please also make sure to add a test to [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py). This is not necessary for non-official training examples. + +### 8. Fixing a "Good second issue" + +*Good second issues* are marked by the [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) label. Good second issues are +usually more complicated to solve than [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). +The issue description usually gives less guidance on how to fix the issue and requires +a decent understanding of the library by the interested contributor. +If you are interested in tackling a good second issue, feel free to open a PR to fix it and link the PR to the issue. If you see that a PR has already been opened for this issue but did not get merged, have a look to understand why it wasn't merged and try to open an improved PR. +Good second issues are usually more difficult to get merged compared to good first issues, so don't hesitate to ask for help from the core maintainers. If your PR is almost finished the core maintainers can also jump into your PR and commit to it in order to get it merged. + +### 9. Adding pipelines, models, schedulers + +Pipelines, models, and schedulers are the most important pieces of the Diffusers library. +They provide easy access to state-of-the-art diffusion technologies and thus allow the community to +build powerful generative AI applications. + +By adding a new model, pipeline, or scheduler you might enable a new powerful use case for any of the user interfaces relying on Diffusers which can be of immense value for the whole generative AI ecosystem. + +Diffusers has a couple of open feature requests for all three components - feel free to gloss over them +if you don't know yet what specific component you would like to add: +- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) +- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) + +Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md) a read to better understand the design of any of the three components. Please be aware that +we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy +as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please +open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design +pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us. + +Please make sure to add links to the original codebase/paper to the PR and ideally also ping the +original author directly on the PR so that they can follow the progress and potentially help with questions. + +If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help. + +## How to write a good issue + +**The better your issue is written, the higher the chances that it will be quickly resolved.** + +1. Make sure that you've used the correct template for your issue. You can pick between *Bug Report*, *Feature Request*, *Feedback about API Design*, *New model/pipeline/scheduler addition*, *Forum*, or a blank issue. Make sure to pick the correct one when opening [a new issue](https://github.com/huggingface/diffusers/issues/new/choose). +2. **Be precise**: Give your issue a fitting title. Try to formulate your issue description as simple as possible. The more precise you are when submitting an issue, the less time it takes to understand the issue and potentially solve it. Make sure to open an issue for one issue only and not for multiple issues. If you found multiple issues, simply open multiple issues. If your issue is a bug, try to be as precise as possible about what bug it is - you should not just write "Error in diffusers". +3. **Reproducibility**: No reproducible code snippet == no solution. If you encounter a bug, maintainers **have to be able to reproduce** it. Make sure that you include a code snippet that can be copy-pasted into a Python interpreter to reproduce the issue. Make sure that your code snippet works, *i.e.* that there are no missing imports or missing links to images, ... Your issue should contain an error message **and** a code snippet that can be copy-pasted without any changes to reproduce the exact same error message. If your issue is using local model weights or local data that cannot be accessed by the reader, the issue cannot be solved. If you cannot share your data or model, try to make a dummy model or dummy data. +4. **Minimalistic**: Try to help the reader as much as you can to understand the issue as quickly as possible by staying as concise as possible. Remove all code / all information that is irrelevant to the issue. If you have found a bug, try to create the easiest code example you can to demonstrate your issue, do not just dump your whole workflow into the issue as soon as you have found a bug. E.g., if you train a model and get an error at some point during the training, you should first try to understand what part of the training code is responsible for the error and try to reproduce it with a couple of lines. Try to use dummy data instead of full datasets. +5. Add links. If you are referring to a certain naming, method, or model make sure to provide a link so that the reader can better understand what you mean. If you are referring to a specific PR or issue, make sure to link it to your issue. Do not assume that the reader knows what you are talking about. The more links you add to your issue the better. +6. Formatting. Make sure to nicely format your issue by formatting code into Python code syntax, and error messages into normal code syntax. See the [official GitHub formatting docs](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) for more information. +7. Think of your issue not as a ticket to be solved, but rather as a beautiful entry to a well-written encyclopedia. Every added issue is a contribution to publicly available knowledge. By adding a nicely written issue you not only make it easier for maintainers to solve your issue, but you are helping the whole community to better understand a certain aspect of the library. + +## How to write a good PR + +1. Be a chameleon. Understand existing design patterns and syntax and make sure your code additions flow seamlessly into the existing code base. Pull requests that significantly diverge from existing design patterns or user interfaces will not be merged. +2. Be laser focused. A pull request should solve one problem and one problem only. Make sure to not fall into the trap of "also fixing another problem while we're adding it". It is much more difficult to review pull requests that solve multiple, unrelated problems at once. +3. If helpful, try to add a code snippet that displays an example of how your addition can be used. +4. The title of your pull request should be a summary of its contribution. +5. If your pull request addresses an issue, please mention the issue number in +the pull request description to make sure they are linked (and people +consulting the issue know you are working on it); +6. To indicate a work in progress please prefix the title with `[WIP]`. These +are useful to avoid duplicated work, and to differentiate it from PRs ready +to be merged; +7. Try to formulate and format your text as explained in [How to write a good issue](#how-to-write-a-good-issue). +8. Make sure existing tests pass; +9. Add high-coverage tests. No quality testing = no merge. +- If you are adding new `@slow` tests, make sure they pass using +`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`. +CircleCI does not run the slow tests, but GitHub Actions does every night! +10. All public methods must have informative docstrings that work nicely with markdown. See [`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) for an example. +11. Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like +[`hf-internal-testing`](https://huggingface.co/hf-internal-testing) or [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images) to place these files. +If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images +to this dataset. + +## How to open a PR + +Before writing code, we strongly advise you to search through the existing PRs or +issues to make sure that nobody is already working on the same thing. If you are +unsure, it is always a good idea to open an issue to get some feedback. + +You will need basic `git` proficiency to be able to contribute to +🧨 Diffusers. `git` is not the easiest tool to use but it has the greatest +manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro +Git](https://git-scm.com/book/en/v2) is a very good reference. + +Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/42f25d601a910dceadaee6c44345896b4cfa9928/setup.py#L270)): + +1. Fork the [repository](https://github.com/huggingface/diffusers) by +clicking on the 'Fork' button on the repository's page. This creates a copy of the code +under your GitHub user account. + +2. Clone your fork to your local disk, and add the base repository as a remote: + + ```bash + $ git clone git@github.com:/diffusers.git + $ cd diffusers + $ git remote add upstream https://github.com/huggingface/diffusers.git + ``` + +3. Create a new branch to hold your development changes: + + ```bash + $ git checkout -b a-descriptive-name-for-my-changes + ``` + +**Do not** work on the `main` branch. + +4. Set up a development environment by running the following command in a virtual environment: + + ```bash + $ pip install -e ".[dev]" + ``` + +If you have already cloned the repo, you might need to `git pull` to get the most recent changes in the +library. + +5. Develop the features on your branch. + +As you work on the features, you should make sure that the test suite +passes. You should run the tests impacted by your changes like this: + + ```bash + $ pytest tests/.py + ``` + +Before you run the tests, please make sure you install the dependencies required for testing. You can do so +with this command: + + ```bash + $ pip install -e ".[test]" + ``` + +You can also run the full test suite with the following command, but it takes +a beefy machine to produce a result in a decent amount of time now that +Diffusers has grown a lot. Here is the command for it: + + ```bash + $ make test + ``` + +🧨 Diffusers relies on `ruff` and `isort` to format its source code +consistently. After you make changes, apply automatic style corrections and code verifications +that can't be automated in one go with: + + ```bash + $ make style + ``` + +🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality +control runs in CI, however, you can also run the same checks with: + + ```bash + $ make quality + ``` + +Once you're happy with your changes, add changed files using `git add` and +make a commit with `git commit` to record your changes locally: + + ```bash + $ git add modified_file.py + $ git commit -m "A descriptive message about your changes." + ``` + +It is a good idea to sync your copy of the code with the original +repository regularly. This way you can quickly account for changes: + + ```bash + $ git pull upstream main + ``` + +Push the changes to your account using: + + ```bash + $ git push -u origin a-descriptive-name-for-my-changes + ``` + +6. Once you are satisfied, go to the +webpage of your fork on GitHub. Click on 'Pull request' to send your changes +to the project maintainers for review. + +7. It's ok if maintainers ask you for changes. It happens to core contributors +too! So everyone can see the changes in the Pull request, work in your local +branch and push the changes to your fork. They will automatically appear in +the pull request. + +### Tests + +An extensive test suite is included to test the library behavior and several examples. Library tests can be found in +the [tests folder](https://github.com/huggingface/diffusers/tree/main/tests). + +We like `pytest` and `pytest-xdist` because it's faster. From the root of the +repository, here's how to run tests with `pytest` for the library: + +```bash +$ python -m pytest -n auto --dist=loadfile -s -v ./tests/ +``` + +In fact, that's how `make test` is implemented! + +You can specify a smaller set of tests in order to test only the feature +you're working on. + +By default, slow tests are skipped. Set the `RUN_SLOW` environment variable to +`yes` to run them. This will download many gigabytes of models — make sure you +have enough disk space and a good Internet connection, or a lot of patience! + +```bash +$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/ +``` + +`unittest` is fully supported, here's how to run tests with it: + +```bash +$ python -m unittest discover -s tests -t . -v +$ python -m unittest discover -s examples -t examples -v +``` + +### Syncing forked main with upstream (HuggingFace) main + +To avoid pinging the upstream repository which adds reference notes to each upstream PR and sends unnecessary notifications to the developers involved in these PRs, +when syncing the main branch of a forked repository, please, follow these steps: +1. When possible, avoid syncing with the upstream using a branch and PR on the forked repository. Instead, merge directly into the forked main. +2. If a PR is absolutely necessary, use the following steps after checking out your branch: +```bash +$ git checkout -b your-branch-for-syncing +$ git pull --squash --no-commit upstream main +$ git commit -m '' +$ git push --set-upstream origin your-branch-for-syncing +``` + +### Style guide + +For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html). diff --git a/diffusers/LICENSE b/diffusers/LICENSE new file mode 100755 index 0000000..261eeb9 --- /dev/null +++ b/diffusers/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/diffusers/MANIFEST.in b/diffusers/MANIFEST.in new file mode 100755 index 0000000..b22fe1a --- /dev/null +++ b/diffusers/MANIFEST.in @@ -0,0 +1,2 @@ +include LICENSE +include src/diffusers/utils/model_card_template.md diff --git a/diffusers/Makefile b/diffusers/Makefile new file mode 100755 index 0000000..9af2e8b --- /dev/null +++ b/diffusers/Makefile @@ -0,0 +1,96 @@ +.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples + +# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) +export PYTHONPATH = src + +check_dirs := examples scripts src tests utils benchmarks + +modified_only_fixup: + $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) + @if test -n "$(modified_py_files)"; then \ + echo "Checking/fixing $(modified_py_files)"; \ + ruff check $(modified_py_files) --fix; \ + ruff format $(modified_py_files);\ + else \ + echo "No library .py files were modified"; \ + fi + +# Update src/diffusers/dependency_versions_table.py + +deps_table_update: + @python setup.py deps_table_update + +deps_table_check_updated: + @md5sum src/diffusers/dependency_versions_table.py > md5sum.saved + @python setup.py deps_table_update + @md5sum -c --quiet md5sum.saved || (printf "\nError: the version dependency table is outdated.\nPlease run 'make fixup' or 'make style' and commit the changes.\n\n" && exit 1) + @rm md5sum.saved + +# autogenerating code + +autogenerate_code: deps_table_update + +# Check that the repo is in a good state + +repo-consistency: + python utils/check_dummies.py + python utils/check_repo.py + python utils/check_inits.py + +# this target runs checks on all files + +quality: + ruff check $(check_dirs) setup.py + ruff format --check $(check_dirs) setup.py + doc-builder style src/diffusers docs/source --max_len 119 --check_only + python utils/check_doc_toc.py + +# Format source code automatically and check is there are any problems left that need manual fixing + +extra_style_checks: + python utils/custom_init_isort.py + python utils/check_doc_toc.py --fix_and_overwrite + +# this target runs checks on all files and potentially modifies some of them + +style: + ruff check $(check_dirs) setup.py --fix + ruff format $(check_dirs) setup.py + doc-builder style src/diffusers docs/source --max_len 119 + ${MAKE} autogenerate_code + ${MAKE} extra_style_checks + +# Super fast fix and check target that only works on relevant modified files since the branch was made + +fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency + +# Make marked copies of snippets of codes conform to the original + +fix-copies: + python utils/check_copies.py --fix_and_overwrite + python utils/check_dummies.py --fix_and_overwrite + +# Run tests for the library + +test: + python -m pytest -n auto --dist=loadfile -s -v ./tests/ + +# Run tests for examples + +test-examples: + python -m pytest -n auto --dist=loadfile -s -v ./examples/ + + +# Release stuff + +pre-release: + python utils/release.py + +pre-patch: + python utils/release.py --patch + +post-release: + python utils/release.py --post_release + +post-patch: + python utils/release.py --post_release --patch diff --git a/diffusers/PHILOSOPHY.md b/diffusers/PHILOSOPHY.md new file mode 100755 index 0000000..c646c61 --- /dev/null +++ b/diffusers/PHILOSOPHY.md @@ -0,0 +1,110 @@ + + +# Philosophy + +🧨 Diffusers provides **state-of-the-art** pretrained diffusion models across multiple modalities. +Its purpose is to serve as a **modular toolbox** for both inference and training. + +We aim to build a library that stands the test of time and therefore take API design very seriously. + +In a nutshell, Diffusers is built to be a natural extension of PyTorch. Therefore, most of our design choices are based on [PyTorch's Design Principles](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy). Let's go over the most important ones: + +## Usability over Performance + +- While Diffusers has many built-in performance-enhancing features (see [Memory and Speed](https://huggingface.co/docs/diffusers/optimization/fp16)), models are always loaded with the highest precision and lowest optimization. Therefore, by default diffusion pipelines are always instantiated on CPU with float32 precision if not otherwise defined by the user. This ensures usability across different platforms and accelerators and means that no complex installations are required to run the library. +- Diffusers aims to be a **light-weight** package and therefore has very few required dependencies, but many soft dependencies that can improve performance (such as `accelerate`, `safetensors`, `onnx`, etc...). We strive to keep the library as lightweight as possible so that it can be added without much concern as a dependency on other packages. +- Diffusers prefers simple, self-explainable code over condensed, magic code. This means that short-hand code syntaxes such as lambda functions, and advanced PyTorch operators are often not desired. + +## Simple over easy + +As PyTorch states, **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library: +- We follow PyTorch's API with methods like [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to) to let the user handle device management. +- Raising concise error messages is preferred to silently correct erroneous input. Diffusers aims at teaching the user, rather than making the library as easy to use as possible. +- Complex model vs. scheduler logic is exposed instead of magically handled inside. Schedulers/Samplers are separated from diffusion models with minimal dependencies on each other. This forces the user to write the unrolled denoising loop. However, the separation allows for easier debugging and gives the user more control over adapting the denoising process or switching out diffusion models or schedulers. +- Separately trained components of the diffusion pipeline, *e.g.* the text encoder, the UNet, and the variational autoencoder, each has their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, this allows for easier debugging and customization. DreamBooth or Textual Inversion training +is very simple thanks to Diffusers' ability to separate single components of the diffusion pipeline. + +## Tweakable, contributor-friendly over abstraction + +For large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers), which is to prefer copy-pasted code over hasty abstractions. This design principle is very opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself). +In short, just like Transformers does for modeling files, Diffusers prefers to keep an extremely low level of abstraction and very self-contained code for pipelines and schedulers. +Functions, long code blocks, and even classes can be copied across multiple files which at first can look like a bad, sloppy design choice that makes the library unmaintainable. +**However**, this design has proven to be extremely successful for Transformers and makes a lot of sense for community-driven, open-source machine learning libraries because: +- Machine Learning is an extremely fast-moving field in which paradigms, model architectures, and algorithms are changing rapidly, which therefore makes it very difficult to define long-lasting code abstractions. +- Machine Learning practitioners like to be able to quickly tweak existing code for ideation and research and therefore prefer self-contained code over one that contains many abstractions. +- Open-source libraries rely on community contributions and therefore must build a library that is easy to contribute to. The more abstract the code, the more dependencies, the harder to read, and the harder to contribute to. Contributors simply stop contributing to very abstract libraries out of fear of breaking vital functionality. If contributing to a library cannot break other fundamental code, not only is it more inviting for potential new contributors, but it is also easier to review and contribute to multiple parts in parallel. + +At Hugging Face, we call this design the **single-file policy** which means that almost all of the code of a certain class should be written in a single, self-contained file. To read more about the philosophy, you can have a look +at [this blog post](https://huggingface.co/blog/transformers-design-philosophy). + +In Diffusers, we follow this philosophy for both pipelines and schedulers, but only partly for diffusion models. The reason we don't follow this design fully for diffusion models is because almost all diffusion pipelines, such +as [DDPM](https://huggingface.co/docs/diffusers/api/pipelines/ddpm), [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview#stable-diffusion-pipelines), [unCLIP (DALL·E 2)](https://huggingface.co/docs/diffusers/api/pipelines/unclip) and [Imagen](https://imagen.research.google/) all rely on the same diffusion model, the [UNet](https://huggingface.co/docs/diffusers/api/models/unet2d-cond). + +Great, now you should have generally understood why 🧨 Diffusers is designed the way it is 🤗. +We try to apply these design principles consistently across the library. Nevertheless, there are some minor exceptions to the philosophy or some unlucky design choices. If you have feedback regarding the design, we would ❤️ to hear it [directly on GitHub](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=). + +## Design Philosophy in Details + +Now, let's look a bit into the nitty-gritty details of the design philosophy. Diffusers essentially consists of three major classes: [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines), [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models), and [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers). +Let's walk through more detailed design decisions for each class. + +### Pipelines + +Pipelines are designed to be easy to use (therefore do not follow [*Simple over easy*](#simple-over-easy) 100%), are not feature complete, and should loosely be seen as examples of how to use [models](#models) and [schedulers](#schedulers) for inference. + +The following design principles are followed: +- Pipelines follow the single-file policy. All pipelines can be found in individual directories under src/diffusers/pipelines. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as it’s done for [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). If pipelines share similar functionality, one can make use of the [# Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251). +- Pipelines all inherit from [`DiffusionPipeline`]. +- Every pipeline consists of different model and scheduler components, that are documented in the [`model_index.json` file](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/model_index.json), are accessible under the same name as attributes of the pipeline and can be shared between pipelines with [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) function. +- Every pipeline should be loadable via the [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) function. +- Pipelines should be used **only** for inference. +- Pipelines should be very readable, self-explanatory, and easy to tweak. +- Pipelines should be designed to build on top of each other and be easy to integrate into higher-level APIs. +- Pipelines are **not** intended to be feature-complete user interfaces. For feature-complete user interfaces one should rather have a look at [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), and [lama-cleaner](https://github.com/Sanster/lama-cleaner). +- Every pipeline should have one and only one way to run it via a `__call__` method. The naming of the `__call__` arguments should be shared across all pipelines. +- Pipelines should be named after the task they are intended to solve. +- In almost all cases, novel diffusion pipelines shall be implemented in a new pipeline folder/file. + +### Models + +Models are designed as configurable toolboxes that are natural extensions of [PyTorch's Module class](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). They only partly follow the **single-file policy**. + +The following design principles are followed: +- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context. +- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unets/unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py), [`transformers/transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py), etc... +- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy. +- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages. +- Models all inherit from `ModelMixin` and `ConfigMixin`. +- Models can be optimized for performance when it doesn’t demand major code changes, keep backward compatibility, and give significant memory or compute gain. +- Models should by default have the highest precision and lowest performance setting. +- To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different. +- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and "foreseeing" future changes, *e.g.* it is usually better to add `string` "...type" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments. Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work. +- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and +readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + +### Schedulers + +Schedulers are responsible to guide the denoising process for inference as well as to define a noise schedule for training. They are designed as individual classes with loadable configuration files and strongly follow the **single-file policy**. + +The following design principles are followed: +- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers). +- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained. +- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper). +- If schedulers share similar functionalities, we can make use of the `# Copied from` mechanism. +- Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`. +- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](./docs/source/en/using-diffusers/schedulers.md). +- Every scheduler has to have a `set_num_inference_steps`, and a `step` function. `set_num_inference_steps(...)` has to be called before every denoising process, *i.e.* before `step(...)` is called. +- Every scheduler exposes the timesteps to be "looped over" via a `timesteps` attribute, which is an array of timesteps the model will be called upon. +- The `step(...)` function takes a predicted model output and the "current" sample (x_t) and returns the "previous", slightly more denoised sample (x_t-1). +- Given the complexity of diffusion schedulers, the `step` function does not expose all the complexity and can be a bit of a "black box". +- In almost all cases, novel schedulers shall be implemented in a new scheduling file. diff --git a/diffusers/README.md b/diffusers/README.md new file mode 100755 index 0000000..b99ca82 --- /dev/null +++ b/diffusers/README.md @@ -0,0 +1,239 @@ + + +

+
+ +
+

+

+ GitHub + GitHub release + GitHub release + Contributor Covenant + X account +

+ +🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or training your own diffusion models, 🤗 Diffusers is a modular toolbox that supports both. Our library is designed with a focus on [usability over performance](https://huggingface.co/docs/diffusers/conceptual/philosophy#usability-over-performance), [simple over easy](https://huggingface.co/docs/diffusers/conceptual/philosophy#simple-over-easy), and [customizability over abstractions](https://huggingface.co/docs/diffusers/conceptual/philosophy#tweakable-contributorfriendly-over-abstraction). + +🤗 Diffusers offers three core components: + +- State-of-the-art [diffusion pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) that can be run in inference with just a few lines of code. +- Interchangeable noise [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview) for different diffusion speeds and output quality. +- Pretrained [models](https://huggingface.co/docs/diffusers/api/models/overview) that can be used as building blocks, and combined with schedulers, for creating your own end-to-end diffusion systems. + +## Installation + +We recommend installing 🤗 Diffusers in a virtual environment from PyPI or Conda. For more details about installing [PyTorch](https://pytorch.org/get-started/locally/) and [Flax](https://flax.readthedocs.io/en/latest/#installation), please refer to their official documentation. + +### PyTorch + +With `pip` (official package): + +```bash +pip install --upgrade diffusers[torch] +``` + +With `conda` (maintained by the community): + +```sh +conda install -c conda-forge diffusers +``` + +### Flax + +With `pip` (official package): + +```bash +pip install --upgrade diffusers[flax] +``` + +### Apple Silicon (M1/M2) support + +Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggingface.co/docs/diffusers/optimization/mps) guide. + +## Quickstart + +Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 30,000+ checkpoints): + +```python +from diffusers import DiffusionPipeline +import torch + +pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16) +pipeline.to("cuda") +pipeline("An image of a squirrel in Picasso style").images[0] +``` + +You can also dig into the models and schedulers toolbox to build your own diffusion system: + +```python +from diffusers import DDPMScheduler, UNet2DModel +from PIL import Image +import torch + +scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256") +model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda") +scheduler.set_timesteps(50) + +sample_size = model.config.sample_size +noise = torch.randn((1, 3, sample_size, sample_size), device="cuda") +input = noise + +for t in scheduler.timesteps: + with torch.no_grad(): + noisy_residual = model(input, t).sample + prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample + input = prev_noisy_sample + +image = (input / 2 + 0.5).clamp(0, 1) +image = image.cpu().permute(0, 2, 3, 1).numpy()[0] +image = Image.fromarray((image * 255).round().astype("uint8")) +image +``` + +Check out the [Quickstart](https://huggingface.co/docs/diffusers/quicktour) to launch your diffusion journey today! + +## How to navigate the documentation + +| **Documentation** | **What can I learn?** | +|---------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [Tutorial](https://huggingface.co/docs/diffusers/tutorials/tutorial_overview) | A basic crash course for learning how to use the library's most important features like using models and schedulers to build your own diffusion system, and training your own diffusion model. | +| [Loading](https://huggingface.co/docs/diffusers/using-diffusers/loading_overview) | Guides for how to load and configure all the components (pipelines, models, and schedulers) of the library, as well as how to use different schedulers. | +| [Pipelines for inference](https://huggingface.co/docs/diffusers/using-diffusers/pipeline_overview) | Guides for how to use pipelines for different inference tasks, batched generation, controlling generated outputs and randomness, and how to contribute a pipeline to the library. | +| [Optimization](https://huggingface.co/docs/diffusers/optimization/opt_overview) | Guides for how to optimize your diffusion model to run faster and consume less memory. | +| [Training](https://huggingface.co/docs/diffusers/training/overview) | Guides for how to train a diffusion model for different tasks with different training techniques. | +## Contribution + +We ❤️ contributions from the open-source community! +If you want to contribute to this library, please check out our [Contribution guide](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md). +You can look out for [issues](https://github.com/huggingface/diffusers/issues) you'd like to tackle to contribute to the library. +- See [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) for general opportunities to contribute +- See [New model/pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) to contribute exciting new diffusion models / diffusion pipelines +- See [New scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) + +Also, say 👋 in our public Discord channel Join us on Discord. We discuss the hottest trends about diffusion models, help each other with contributions, personal projects or just hang out ☕. + + +## Popular Tasks & Pipelines + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
TaskPipeline🤗 Hub
Unconditional Image Generation DDPM google/ddpm-ema-church-256
Text-to-ImageStable Diffusion Text-to-Image stable-diffusion-v1-5/stable-diffusion-v1-5
Text-to-ImageunCLIP kakaobrain/karlo-v1-alpha
Text-to-ImageDeepFloyd IF DeepFloyd/IF-I-XL-v1.0
Text-to-ImageKandinsky kandinsky-community/kandinsky-2-2-decoder
Text-guided Image-to-ImageControlNet lllyasviel/sd-controlnet-canny
Text-guided Image-to-ImageInstructPix2Pix timbrooks/instruct-pix2pix
Text-guided Image-to-ImageStable Diffusion Image-to-Image stable-diffusion-v1-5/stable-diffusion-v1-5
Text-guided Image InpaintingStable Diffusion Inpainting runwayml/stable-diffusion-inpainting
Image VariationStable Diffusion Image Variation lambdalabs/sd-image-variations-diffusers
Super ResolutionStable Diffusion Upscale stabilityai/stable-diffusion-x4-upscaler
Super ResolutionStable Diffusion Latent Upscale stabilityai/sd-x2-latent-upscaler
+ +## Popular libraries using 🧨 Diffusers + +- https://github.com/microsoft/TaskMatrix +- https://github.com/invoke-ai/InvokeAI +- https://github.com/InstantID/InstantID +- https://github.com/apple/ml-stable-diffusion +- https://github.com/Sanster/lama-cleaner +- https://github.com/IDEA-Research/Grounded-Segment-Anything +- https://github.com/ashawkey/stable-dreamfusion +- https://github.com/deep-floyd/IF +- https://github.com/bentoml/BentoML +- https://github.com/bmaltais/kohya_ss +- +14,000 other amazing GitHub repositories 💪 + +Thank you for using us ❤️. + +## Credits + +This library concretizes previous work by many different authors and would not have been possible without their great research and implementations. We'd like to thank, in particular, the following implementations which have helped us in our development and without which the API could not have been as polished today: + +- @CompVis' latent diffusion models library, available [here](https://github.com/CompVis/latent-diffusion) +- @hojonathanho original DDPM implementation, available [here](https://github.com/hojonathanho/diffusion) as well as the extremely useful translation into PyTorch by @pesser, available [here](https://github.com/pesser/pytorch_diffusion) +- @ermongroup's DDIM implementation, available [here](https://github.com/ermongroup/ddim) +- @yang-song's Score-VE and Score-VP implementations, available [here](https://github.com/yang-song/score_sde_pytorch) + +We also want to thank @heejkoo for the very helpful overview of papers, code and resources on diffusion models, available [here](https://github.com/heejkoo/Awesome-Diffusion-Models) as well as @crowsonkb and @rromb for useful discussions and insights. + +## Citation + +```bibtex +@misc{von-platen-etal-2022-diffusers, + author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Dhruv Nair and Sayak Paul and William Berman and Yiyi Xu and Steven Liu and Thomas Wolf}, + title = {Diffusers: State-of-the-art diffusion models}, + year = {2022}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/huggingface/diffusers}} +} +``` diff --git a/diffusers/_typos.toml b/diffusers/_typos.toml new file mode 100755 index 0000000..551099f --- /dev/null +++ b/diffusers/_typos.toml @@ -0,0 +1,13 @@ +# Files for typos +# Instruction: https://github.com/marketplace/actions/typos-action#getting-started + +[default.extend-identifiers] + +[default.extend-words] +NIN="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py +nd="np" # nd may be np (numpy) +parms="parms" # parms is used in scripts/convert_original_stable_diffusion_to_diffusers.py + + +[files] +extend-exclude = ["_typos.toml"] diff --git a/diffusers/pyproject.toml b/diffusers/pyproject.toml new file mode 100755 index 0000000..299865a --- /dev/null +++ b/diffusers/pyproject.toml @@ -0,0 +1,29 @@ +[tool.ruff] +line-length = 119 + +[tool.ruff.lint] +# Never enforce `E501` (line length violations). +ignore = ["C901", "E501", "E741", "F402", "F823"] +select = ["C", "E", "F", "I", "W"] + +# Ignore import violations in all `__init__.py` files. +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402", "F401", "F403", "F811"] +"src/diffusers/utils/dummy_*.py" = ["F401"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 +known-first-party = ["diffusers"] + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" diff --git a/diffusers/scripts/__init__.py b/diffusers/scripts/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/diffusers/scripts/change_naming_configs_and_checkpoints.py b/diffusers/scripts/change_naming_configs_and_checkpoints.py new file mode 100755 index 0000000..adc1605 --- /dev/null +++ b/diffusers/scripts/change_naming_configs_and_checkpoints.py @@ -0,0 +1,113 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the LDM checkpoints.""" + +import argparse +import json +import os + +import torch +from transformers.file_utils import has_file + +from diffusers import UNet2DConditionModel, UNet2DModel + + +do_only_config = False +do_only_weights = True +do_only_renaming = False + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--repo_path", + default=None, + type=str, + required=True, + help="The config json file corresponding to the architecture.", + ) + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + args = parser.parse_args() + + config_parameters_to_change = { + "image_size": "sample_size", + "num_res_blocks": "layers_per_block", + "block_channels": "block_out_channels", + "down_blocks": "down_block_types", + "up_blocks": "up_block_types", + "downscale_freq_shift": "freq_shift", + "resnet_num_groups": "norm_num_groups", + "resnet_act_fn": "act_fn", + "resnet_eps": "norm_eps", + "num_head_channels": "attention_head_dim", + } + + key_parameters_to_change = { + "time_steps": "time_proj", + "mid": "mid_block", + "downsample_blocks": "down_blocks", + "upsample_blocks": "up_blocks", + } + + subfolder = "" if has_file(args.repo_path, "config.json") else "unet" + + with open(os.path.join(args.repo_path, subfolder, "config.json"), "r", encoding="utf-8") as reader: + text = reader.read() + config = json.loads(text) + + if do_only_config: + for key in config_parameters_to_change.keys(): + config.pop(key, None) + + if has_file(args.repo_path, "config.json"): + model = UNet2DModel(**config) + else: + class_name = UNet2DConditionModel if "ldm-text2im-large-256" in args.repo_path else UNet2DModel + model = class_name(**config) + + if do_only_config: + model.save_config(os.path.join(args.repo_path, subfolder)) + + config = dict(model.config) + + if do_only_renaming: + for key, value in config_parameters_to_change.items(): + if key in config: + config[value] = config[key] + del config[key] + + config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]] + config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]] + + if do_only_weights: + state_dict = torch.load(os.path.join(args.repo_path, subfolder, "diffusion_pytorch_model.bin")) + + new_state_dict = {} + for param_key, param_value in state_dict.items(): + if param_key.endswith(".op.bias") or param_key.endswith(".op.weight"): + continue + has_changed = False + for key, new_key in key_parameters_to_change.items(): + if not has_changed and param_key.split(".")[0] == key: + new_state_dict[".".join([new_key] + param_key.split(".")[1:])] = param_value + has_changed = True + if not has_changed: + new_state_dict[param_key] = param_value + + model.load_state_dict(new_state_dict) + model.save_pretrained(os.path.join(args.repo_path, subfolder)) diff --git a/diffusers/scripts/conversion_ldm_uncond.py b/diffusers/scripts/conversion_ldm_uncond.py new file mode 100755 index 0000000..8c22cc1 --- /dev/null +++ b/diffusers/scripts/conversion_ldm_uncond.py @@ -0,0 +1,56 @@ +import argparse + +import torch +import yaml + +from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel + + +def convert_ldm_original(checkpoint_path, config_path, output_path): + config = yaml.safe_load(config_path) + state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] + keys = list(state_dict.keys()) + + # extract state_dict for VQVAE + first_stage_dict = {} + first_stage_key = "first_stage_model." + for key in keys: + if key.startswith(first_stage_key): + first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key] + + # extract state_dict for UNetLDM + unet_state_dict = {} + unet_key = "model.diffusion_model." + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = state_dict[key] + + vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"] + unet_init_args = config["model"]["params"]["unet_config"]["params"] + + vqvae = VQModel(**vqvae_init_args).eval() + vqvae.load_state_dict(first_stage_dict) + + unet = UNetLDMModel(**unet_init_args).eval() + unet.load_state_dict(unet_state_dict) + + noise_scheduler = DDIMScheduler( + timesteps=config["model"]["params"]["timesteps"], + beta_schedule="scaled_linear", + beta_start=config["model"]["params"]["linear_start"], + beta_end=config["model"]["params"]["linear_end"], + clip_sample=False, + ) + + pipeline = LDMPipeline(vqvae, unet, noise_scheduler) + pipeline.save_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--config_path", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + args = parser.parse_args() + + convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path) diff --git a/diffusers/scripts/convert_amused.py b/diffusers/scripts/convert_amused.py new file mode 100755 index 0000000..21be29d --- /dev/null +++ b/diffusers/scripts/convert_amused.py @@ -0,0 +1,523 @@ +import inspect +import os +from argparse import ArgumentParser + +import numpy as np +import torch +from muse import MaskGiTUViT, VQGANModel +from muse import PipelineMuse as OldPipelineMuse +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import VQModel +from diffusers.models.attention_processor import AttnProcessor +from diffusers.models.unets.uvit_2d import UVit2DModel +from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline +from diffusers.schedulers import AmusedScheduler + + +torch.backends.cuda.enable_flash_sdp(False) +torch.backends.cuda.enable_mem_efficient_sdp(False) +torch.backends.cuda.enable_math_sdp(True) + +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" +torch.use_deterministic_algorithms(True) + +# Enable CUDNN deterministic mode +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +torch.backends.cuda.matmul.allow_tf32 = False + +device = "cuda" + + +def main(): + args = ArgumentParser() + args.add_argument("--model_256", action="store_true") + args.add_argument("--write_to", type=str, required=False, default=None) + args.add_argument("--transformer_path", type=str, required=False, default=None) + args = args.parse_args() + + transformer_path = args.transformer_path + subfolder = "transformer" + + if transformer_path is None: + if args.model_256: + transformer_path = "openMUSE/muse-256" + else: + transformer_path = ( + "../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/" + ) + subfolder = None + + old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder) + + old_transformer.to(device) + + old_vae = VQGANModel.from_pretrained("openMUSE/muse-512", subfolder="vae") + old_vae.to(device) + + vqvae = make_vqvae(old_vae) + + tokenizer = CLIPTokenizer.from_pretrained("openMUSE/muse-512", subfolder="text_encoder") + + text_encoder = CLIPTextModelWithProjection.from_pretrained("openMUSE/muse-512", subfolder="text_encoder") + text_encoder.to(device) + + transformer = make_transformer(old_transformer, args.model_256) + + scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id) + + new_pipe = AmusedPipeline( + vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler + ) + + old_pipe = OldPipelineMuse( + vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer + ) + old_pipe.to(device) + + if args.model_256: + transformer_seq_len = 256 + orig_size = (256, 256) + else: + transformer_seq_len = 1024 + orig_size = (512, 512) + + old_out = old_pipe( + "dog", + generator=torch.Generator(device).manual_seed(0), + transformer_seq_len=transformer_seq_len, + orig_size=orig_size, + timesteps=12, + )[0] + + new_out = new_pipe("dog", generator=torch.Generator(device).manual_seed(0)).images[0] + + old_out = np.array(old_out) + new_out = np.array(new_out) + + diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64)) + + # assert diff diff.sum() == 0 + print("skipping pipeline full equivalence check") + + print(f"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}") + + if args.model_256: + assert diff.max() <= 3 + assert diff.sum() / diff.size < 0.7 + else: + assert diff.max() <= 1 + assert diff.sum() / diff.size < 0.4 + + if args.write_to is not None: + new_pipe.save_pretrained(args.write_to) + + +def make_transformer(old_transformer, model_256): + args = dict(old_transformer.config) + force_down_up_sample = args["force_down_up_sample"] + + signature = inspect.signature(UVit2DModel.__init__) + + args_ = { + "downsample": force_down_up_sample, + "upsample": force_down_up_sample, + "block_out_channels": args["block_out_channels"][0], + "sample_size": 16 if model_256 else 32, + } + + for s in list(signature.parameters.keys()): + if s in ["self", "downsample", "upsample", "sample_size", "block_out_channels"]: + continue + + args_[s] = args[s] + + new_transformer = UVit2DModel(**args_) + new_transformer.to(device) + + new_transformer.set_attn_processor(AttnProcessor()) + + state_dict = old_transformer.state_dict() + + state_dict["cond_embed.linear_1.weight"] = state_dict.pop("cond_embed.0.weight") + state_dict["cond_embed.linear_2.weight"] = state_dict.pop("cond_embed.2.weight") + + for i in range(22): + state_dict[f"transformer_layers.{i}.norm1.norm.weight"] = state_dict.pop( + f"transformer_layers.{i}.attn_layer_norm.weight" + ) + state_dict[f"transformer_layers.{i}.norm1.linear.weight"] = state_dict.pop( + f"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight" + ) + + state_dict[f"transformer_layers.{i}.attn1.to_q.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.query.weight" + ) + state_dict[f"transformer_layers.{i}.attn1.to_k.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.key.weight" + ) + state_dict[f"transformer_layers.{i}.attn1.to_v.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.value.weight" + ) + state_dict[f"transformer_layers.{i}.attn1.to_out.0.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.out.weight" + ) + + state_dict[f"transformer_layers.{i}.norm2.norm.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattn_layer_norm.weight" + ) + state_dict[f"transformer_layers.{i}.norm2.linear.weight"] = state_dict.pop( + f"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight" + ) + + state_dict[f"transformer_layers.{i}.attn2.to_q.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.query.weight" + ) + state_dict[f"transformer_layers.{i}.attn2.to_k.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.key.weight" + ) + state_dict[f"transformer_layers.{i}.attn2.to_v.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.value.weight" + ) + state_dict[f"transformer_layers.{i}.attn2.to_out.0.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.out.weight" + ) + + state_dict[f"transformer_layers.{i}.norm3.norm.weight"] = state_dict.pop( + f"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight" + ) + state_dict[f"transformer_layers.{i}.norm3.linear.weight"] = state_dict.pop( + f"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight" + ) + + wi_0_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_0.weight") + wi_1_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_1.weight") + proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0) + state_dict[f"transformer_layers.{i}.ff.net.0.proj.weight"] = proj_weight + + state_dict[f"transformer_layers.{i}.ff.net.2.weight"] = state_dict.pop(f"transformer_layers.{i}.ffn.wo.weight") + + if force_down_up_sample: + state_dict["down_block.downsample.norm.weight"] = state_dict.pop("down_blocks.0.downsample.0.norm.weight") + state_dict["down_block.downsample.conv.weight"] = state_dict.pop("down_blocks.0.downsample.1.weight") + + state_dict["up_block.upsample.norm.weight"] = state_dict.pop("up_blocks.0.upsample.0.norm.weight") + state_dict["up_block.upsample.conv.weight"] = state_dict.pop("up_blocks.0.upsample.1.weight") + + state_dict["mlm_layer.layer_norm.weight"] = state_dict.pop("mlm_layer.layer_norm.norm.weight") + + for i in range(3): + state_dict[f"down_block.res_blocks.{i}.norm.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.norm.norm.weight" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.0.weight" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.2.gamma" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.2.beta" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.4.weight" + ) + state_dict[f"down_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight" + ) + + state_dict[f"down_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.query.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.key.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.value.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.out.weight" + ) + + state_dict[f"down_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.query.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.key.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.value.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.out.weight" + ) + + state_dict[f"up_block.res_blocks.{i}.norm.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.norm.norm.weight" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.0.weight" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.2.gamma" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.2.beta" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.4.weight" + ) + state_dict[f"up_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight" + ) + + state_dict[f"up_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.query.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.key.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.value.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.out.weight" + ) + + state_dict[f"up_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.query.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.key.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.value.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.out.weight" + ) + + for key in list(state_dict.keys()): + if key.startswith("up_blocks.0"): + key_ = "up_block." + ".".join(key.split(".")[2:]) + state_dict[key_] = state_dict.pop(key) + + if key.startswith("down_blocks.0"): + key_ = "down_block." + ".".join(key.split(".")[2:]) + state_dict[key_] = state_dict.pop(key) + + new_transformer.load_state_dict(state_dict) + + input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device) + encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device) + cond_embeds = torch.randn((1, 768), device=old_transformer.device) + micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device) + + old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds) + old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2) + + new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds) + + # NOTE: these differences are solely due to using the geglu block that has a single linear layer of + # double output dimension instead of two different linear layers + max_diff = (old_out - new_out).abs().max() + total_diff = (old_out - new_out).abs().sum() + print(f"Transformer max_diff: {max_diff} total_diff: {total_diff}") + assert max_diff < 0.01 + assert total_diff < 1500 + + return new_transformer + + +def make_vqvae(old_vae): + new_vae = VQModel( + act_fn="silu", + block_out_channels=[128, 256, 256, 512, 768], + down_block_types=[ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ], + in_channels=3, + latent_channels=64, + layers_per_block=2, + norm_num_groups=32, + num_vq_embeddings=8192, + out_channels=3, + sample_size=32, + up_block_types=[ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ], + mid_block_add_attention=False, + lookup_from_codebook=True, + ) + new_vae.to(device) + + # fmt: off + + new_state_dict = {} + + old_state_dict = old_vae.state_dict() + + new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight") + new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias") + + convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0") + convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1") + convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2") + convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3") + convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4") + + new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight") + new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias") + new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight") + new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias") + new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight") + new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias") + new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight") + new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias") + new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight") + new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias") + new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight") + new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias") + new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight") + new_state_dict["encoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("encoder.mid.block_2.norm2.bias") + new_state_dict["encoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("encoder.mid.block_2.conv2.weight") + new_state_dict["encoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("encoder.mid.block_2.conv2.bias") + new_state_dict["encoder.conv_norm_out.weight"] = old_state_dict.pop("encoder.norm_out.weight") + new_state_dict["encoder.conv_norm_out.bias"] = old_state_dict.pop("encoder.norm_out.bias") + new_state_dict["encoder.conv_out.weight"] = old_state_dict.pop("encoder.conv_out.weight") + new_state_dict["encoder.conv_out.bias"] = old_state_dict.pop("encoder.conv_out.bias") + new_state_dict["quant_conv.weight"] = old_state_dict.pop("quant_conv.weight") + new_state_dict["quant_conv.bias"] = old_state_dict.pop("quant_conv.bias") + new_state_dict["quantize.embedding.weight"] = old_state_dict.pop("quantize.embedding.weight") + new_state_dict["post_quant_conv.weight"] = old_state_dict.pop("post_quant_conv.weight") + new_state_dict["post_quant_conv.bias"] = old_state_dict.pop("post_quant_conv.bias") + new_state_dict["decoder.conv_in.weight"] = old_state_dict.pop("decoder.conv_in.weight") + new_state_dict["decoder.conv_in.bias"] = old_state_dict.pop("decoder.conv_in.bias") + new_state_dict["decoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("decoder.mid.block_1.norm1.weight") + new_state_dict["decoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("decoder.mid.block_1.norm1.bias") + new_state_dict["decoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("decoder.mid.block_1.conv1.weight") + new_state_dict["decoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("decoder.mid.block_1.conv1.bias") + new_state_dict["decoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("decoder.mid.block_1.norm2.weight") + new_state_dict["decoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("decoder.mid.block_1.norm2.bias") + new_state_dict["decoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("decoder.mid.block_1.conv2.weight") + new_state_dict["decoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("decoder.mid.block_1.conv2.bias") + new_state_dict["decoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("decoder.mid.block_2.norm1.weight") + new_state_dict["decoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("decoder.mid.block_2.norm1.bias") + new_state_dict["decoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("decoder.mid.block_2.conv1.weight") + new_state_dict["decoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("decoder.mid.block_2.conv1.bias") + new_state_dict["decoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("decoder.mid.block_2.norm2.weight") + new_state_dict["decoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("decoder.mid.block_2.norm2.bias") + new_state_dict["decoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("decoder.mid.block_2.conv2.weight") + new_state_dict["decoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("decoder.mid.block_2.conv2.bias") + + convert_vae_block_state_dict(old_state_dict, "decoder.up.0", new_state_dict, "decoder.up_blocks.4") + convert_vae_block_state_dict(old_state_dict, "decoder.up.1", new_state_dict, "decoder.up_blocks.3") + convert_vae_block_state_dict(old_state_dict, "decoder.up.2", new_state_dict, "decoder.up_blocks.2") + convert_vae_block_state_dict(old_state_dict, "decoder.up.3", new_state_dict, "decoder.up_blocks.1") + convert_vae_block_state_dict(old_state_dict, "decoder.up.4", new_state_dict, "decoder.up_blocks.0") + + new_state_dict["decoder.conv_norm_out.weight"] = old_state_dict.pop("decoder.norm_out.weight") + new_state_dict["decoder.conv_norm_out.bias"] = old_state_dict.pop("decoder.norm_out.bias") + new_state_dict["decoder.conv_out.weight"] = old_state_dict.pop("decoder.conv_out.weight") + new_state_dict["decoder.conv_out.bias"] = old_state_dict.pop("decoder.conv_out.bias") + + # fmt: on + + assert len(old_state_dict.keys()) == 0 + + new_vae.load_state_dict(new_state_dict) + + input = torch.randn((1, 3, 512, 512), device=device) + input = input.clamp(-1, 1) + + old_encoder_output = old_vae.quant_conv(old_vae.encoder(input)) + new_encoder_output = new_vae.quant_conv(new_vae.encoder(input)) + assert (old_encoder_output == new_encoder_output).all() + + old_decoder_output = old_vae.decoder(old_vae.post_quant_conv(old_encoder_output)) + new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output)) + + # assert (old_decoder_output == new_decoder_output).all() + print("kipping vae decoder equivalence check") + print(f"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}") + + old_output = old_vae(input)[0] + new_output = new_vae(input)[0] + + # assert (old_output == new_output).all() + print("skipping full vae equivalence check") + print(f"vae full diff { (old_output - new_output).float().abs().sum()}") + + return new_vae + + +def convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to): + # fmt: off + + new_state_dict[f"{prefix_to}.resnets.0.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.weight") + new_state_dict[f"{prefix_to}.resnets.0.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.bias") + new_state_dict[f"{prefix_to}.resnets.0.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.weight") + new_state_dict[f"{prefix_to}.resnets.0.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.bias") + new_state_dict[f"{prefix_to}.resnets.0.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.weight") + new_state_dict[f"{prefix_to}.resnets.0.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.bias") + new_state_dict[f"{prefix_to}.resnets.0.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.weight") + new_state_dict[f"{prefix_to}.resnets.0.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.bias") + + if f"{prefix_from}.block.0.nin_shortcut.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.weight") + new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.bias") + + new_state_dict[f"{prefix_to}.resnets.1.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.weight") + new_state_dict[f"{prefix_to}.resnets.1.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.bias") + new_state_dict[f"{prefix_to}.resnets.1.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.weight") + new_state_dict[f"{prefix_to}.resnets.1.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.bias") + new_state_dict[f"{prefix_to}.resnets.1.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.weight") + new_state_dict[f"{prefix_to}.resnets.1.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.bias") + new_state_dict[f"{prefix_to}.resnets.1.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.weight") + new_state_dict[f"{prefix_to}.resnets.1.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.bias") + + if f"{prefix_from}.downsample.conv.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.downsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.weight") + new_state_dict[f"{prefix_to}.downsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.bias") + + if f"{prefix_from}.upsample.conv.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.upsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.weight") + new_state_dict[f"{prefix_to}.upsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.bias") + + if f"{prefix_from}.block.2.norm1.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.resnets.2.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.weight") + new_state_dict[f"{prefix_to}.resnets.2.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.bias") + new_state_dict[f"{prefix_to}.resnets.2.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.weight") + new_state_dict[f"{prefix_to}.resnets.2.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.bias") + new_state_dict[f"{prefix_to}.resnets.2.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.weight") + new_state_dict[f"{prefix_to}.resnets.2.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.bias") + new_state_dict[f"{prefix_to}.resnets.2.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.weight") + new_state_dict[f"{prefix_to}.resnets.2.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.bias") + + # fmt: on + + +if __name__ == "__main__": + main() diff --git a/diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py b/diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py new file mode 100755 index 0000000..21567ff --- /dev/null +++ b/diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py @@ -0,0 +1,69 @@ +import argparse +import os + +import torch +from huggingface_hub import create_repo, upload_folder +from safetensors.torch import load_file, save_file + + +def convert_motion_module(original_state_dict): + converted_state_dict = {} + for k, v in original_state_dict.items(): + if "pos_encoder" in k: + continue + + else: + converted_state_dict[ + k.replace(".norms.0", ".norm1") + .replace(".norms.1", ".norm2") + .replace(".ff_norm", ".norm3") + .replace(".attention_blocks.0", ".attn1") + .replace(".attention_blocks.1", ".attn2") + .replace(".temporal_transformer", "") + ] = v + + return converted_state_dict + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint") + parser.add_argument("--output_path", type=str, required=True, help="Path to output directory") + parser.add_argument( + "--push_to_hub", + action="store_true", + default=False, + help="Whether to push the converted model to the HF or not", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + if args.ckpt_path.endswith(".safetensors"): + state_dict = load_file(args.ckpt_path) + else: + state_dict = torch.load(args.ckpt_path, map_location="cpu") + + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + + conv_state_dict = convert_motion_module(state_dict) + + # convert to new format + output_dict = {} + for module_name, params in conv_state_dict.items(): + if type(params) is not torch.Tensor: + continue + output_dict.update({f"unet.{module_name}": params}) + + os.makedirs(args.output_path, exist_ok=True) + + filepath = os.path.join(args.output_path, "diffusion_pytorch_model.safetensors") + save_file(output_dict, filepath) + + if args.push_to_hub: + repo_id = create_repo(args.output_path, exist_ok=True).repo_id + upload_folder(repo_id=repo_id, folder_path=args.output_path, repo_type="model") diff --git a/diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py b/diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py new file mode 100755 index 0000000..e188a6a --- /dev/null +++ b/diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py @@ -0,0 +1,62 @@ +import argparse + +import torch +from safetensors.torch import load_file + +from diffusers import MotionAdapter + + +def convert_motion_module(original_state_dict): + converted_state_dict = {} + for k, v in original_state_dict.items(): + if "pos_encoder" in k: + continue + + else: + converted_state_dict[ + k.replace(".norms.0", ".norm1") + .replace(".norms.1", ".norm2") + .replace(".ff_norm", ".norm3") + .replace(".attention_blocks.0", ".attn1") + .replace(".attention_blocks.1", ".attn2") + .replace(".temporal_transformer", "") + ] = v + + return converted_state_dict + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--use_motion_mid_block", action="store_true") + parser.add_argument("--motion_max_seq_length", type=int, default=32) + parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int) + parser.add_argument("--save_fp16", action="store_true") + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + if args.ckpt_path.endswith(".safetensors"): + state_dict = load_file(args.ckpt_path) + else: + state_dict = torch.load(args.ckpt_path, map_location="cpu") + + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + + conv_state_dict = convert_motion_module(state_dict) + adapter = MotionAdapter( + block_out_channels=args.block_out_channels, + use_motion_mid_block=args.use_motion_mid_block, + motion_max_seq_length=args.motion_max_seq_length, + ) + # skip loading position embeddings + adapter.load_state_dict(conv_state_dict, strict=False) + adapter.save_pretrained(args.output_path) + + if args.save_fp16: + adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant="fp16") diff --git a/diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py b/diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py new file mode 100755 index 0000000..f246dce --- /dev/null +++ b/diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py @@ -0,0 +1,83 @@ +import argparse +from typing import Dict + +import torch +import torch.nn as nn + +from diffusers import SparseControlNetModel + + +KEYS_RENAME_MAPPING = { + ".attention_blocks.0": ".attn1", + ".attention_blocks.1": ".attn2", + ".attn1.pos_encoder": ".pos_embed", + ".ff_norm": ".norm3", + ".norms.0": ".norm1", + ".norms.1": ".norm2", + ".temporal_transformer": "", +} + + +def convert(original_state_dict: Dict[str, nn.Module]) -> Dict[str, nn.Module]: + converted_state_dict = {} + + for key in list(original_state_dict.keys()): + renamed_key = key + for new_name, old_name in KEYS_RENAME_MAPPING.items(): + renamed_key = renamed_key.replace(new_name, old_name) + converted_state_dict[renamed_key] = original_state_dict.pop(key) + + return converted_state_dict + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint") + parser.add_argument("--output_path", type=str, required=True, help="Path to output directory") + parser.add_argument( + "--max_motion_seq_length", + type=int, + default=32, + help="Max motion sequence length supported by the motion adapter", + ) + parser.add_argument( + "--conditioning_channels", type=int, default=4, help="Number of channels in conditioning input to controlnet" + ) + parser.add_argument( + "--use_simplified_condition_embedding", + action="store_true", + default=False, + help="Whether or not to use simplified condition embedding. When `conditioning_channels==4` i.e. latent inputs, set this to `True`. When `conditioning_channels==3` i.e. image inputs, set this to `False`", + ) + parser.add_argument( + "--save_fp16", + action="store_true", + default=False, + help="Whether or not to save model in fp16 precision along with fp32", + ) + parser.add_argument( + "--push_to_hub", action="store_true", default=False, help="Whether or not to push saved model to the HF hub" + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + state_dict = torch.load(args.ckpt_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict: dict = state_dict["state_dict"] + + controlnet = SparseControlNetModel( + conditioning_channels=args.conditioning_channels, + motion_max_seq_length=args.max_motion_seq_length, + use_simplified_condition_embedding=args.use_simplified_condition_embedding, + ) + + state_dict = convert(state_dict) + controlnet.load_state_dict(state_dict, strict=True) + + controlnet.save_pretrained(args.output_path, push_to_hub=args.push_to_hub) + if args.save_fp16: + controlnet = controlnet.to(dtype=torch.float16) + controlnet.save_pretrained(args.output_path, variant="fp16", push_to_hub=args.push_to_hub) diff --git a/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py b/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py new file mode 100755 index 0000000..ffb735e --- /dev/null +++ b/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py @@ -0,0 +1,184 @@ +import argparse +import time +from pathlib import Path +from typing import Any, Dict, Literal + +import torch + +from diffusers import AsymmetricAutoencoderKL + + +ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = { + "in_channels": 3, + "out_channels": 3, + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ], + "down_block_out_channels": [128, 256, 512, 512], + "layers_per_down_block": 2, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ], + "up_block_out_channels": [192, 384, 768, 768], + "layers_per_up_block": 3, + "act_fn": "silu", + "latent_channels": 4, + "norm_num_groups": 32, + "sample_size": 256, + "scaling_factor": 0.18215, +} + +ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = { + "in_channels": 3, + "out_channels": 3, + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ], + "down_block_out_channels": [128, 256, 512, 512], + "layers_per_down_block": 2, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ], + "up_block_out_channels": [256, 512, 1024, 1024], + "layers_per_up_block": 5, + "act_fn": "silu", + "latent_channels": 4, + "norm_num_groups": 32, + "sample_size": 256, + "scaling_factor": 0.18215, +} + + +def convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> Dict[str, Any]: + converted_state_dict = {} + for k, v in original_state_dict.items(): + if k.startswith("encoder."): + converted_state_dict[ + k.replace("encoder.down.", "encoder.down_blocks.") + .replace("encoder.mid.", "encoder.mid_block.") + .replace("encoder.norm_out.", "encoder.conv_norm_out.") + .replace(".downsample.", ".downsamplers.0.") + .replace(".nin_shortcut.", ".conv_shortcut.") + .replace(".block.", ".resnets.") + .replace(".block_1.", ".resnets.0.") + .replace(".block_2.", ".resnets.1.") + .replace(".attn_1.k.", ".attentions.0.to_k.") + .replace(".attn_1.q.", ".attentions.0.to_q.") + .replace(".attn_1.v.", ".attentions.0.to_v.") + .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.") + .replace(".attn_1.norm.", ".attentions.0.group_norm.") + ] = v + elif k.startswith("decoder.") and "up_layers" not in k: + converted_state_dict[ + k.replace("decoder.encoder.", "decoder.condition_encoder.") + .replace(".norm_out.", ".conv_norm_out.") + .replace(".up.0.", ".up_blocks.3.") + .replace(".up.1.", ".up_blocks.2.") + .replace(".up.2.", ".up_blocks.1.") + .replace(".up.3.", ".up_blocks.0.") + .replace(".block.", ".resnets.") + .replace("mid", "mid_block") + .replace(".0.upsample.", ".0.upsamplers.0.") + .replace(".1.upsample.", ".1.upsamplers.0.") + .replace(".2.upsample.", ".2.upsamplers.0.") + .replace(".nin_shortcut.", ".conv_shortcut.") + .replace(".block_1.", ".resnets.0.") + .replace(".block_2.", ".resnets.1.") + .replace(".attn_1.k.", ".attentions.0.to_k.") + .replace(".attn_1.q.", ".attentions.0.to_q.") + .replace(".attn_1.v.", ".attentions.0.to_v.") + .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.") + .replace(".attn_1.norm.", ".attentions.0.group_norm.") + ] = v + elif k.startswith("quant_conv."): + converted_state_dict[k] = v + elif k.startswith("post_quant_conv."): + converted_state_dict[k] = v + else: + print(f" skipping key `{k}`") + # fix weights shape + for k, v in converted_state_dict.items(): + if ( + (k.startswith("encoder.mid_block.attentions.0") or k.startswith("decoder.mid_block.attentions.0")) + and k.endswith("weight") + and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k) + ): + converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0] + + return converted_state_dict + + +def get_asymmetric_autoencoder_kl_from_original_checkpoint( + scale: Literal["1.5", "2"], original_checkpoint_path: str, map_location: torch.device +) -> AsymmetricAutoencoderKL: + print("Loading original state_dict") + original_state_dict = torch.load(original_checkpoint_path, map_location=map_location) + original_state_dict = original_state_dict["state_dict"] + print("Converting state_dict") + converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict) + kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == "1.5" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG + print("Initializing AsymmetricAutoencoderKL model") + asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs) + print("Loading weight from converted state_dict") + asymmetric_autoencoder_kl.load_state_dict(converted_state_dict) + asymmetric_autoencoder_kl.eval() + print("AsymmetricAutoencoderKL successfully initialized") + return asymmetric_autoencoder_kl + + +if __name__ == "__main__": + start = time.time() + parser = argparse.ArgumentParser() + parser.add_argument( + "--scale", + default=None, + type=str, + required=True, + help="Asymmetric VQGAN scale: `1.5` or `2`", + ) + parser.add_argument( + "--original_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the original Asymmetric VQGAN checkpoint", + ) + parser.add_argument( + "--output_path", + default=None, + type=str, + required=True, + help="Path to save pretrained AsymmetricAutoencoderKL model", + ) + parser.add_argument( + "--map_location", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading the checkpoint", + ) + args = parser.parse_args() + + assert args.scale in ["1.5", "2"], f"{args.scale} should be `1.5` of `2`" + assert Path(args.original_checkpoint_path).is_file() + + asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint( + scale=args.scale, + original_checkpoint_path=args.original_checkpoint_path, + map_location=torch.device(args.map_location), + ) + print("Saving pretrained AsymmetricAutoencoderKL") + asymmetric_autoencoder_kl.save_pretrained(args.output_path) + print(f"Done in {time.time() - start:.2f} seconds") diff --git a/diffusers/scripts/convert_aura_flow_to_diffusers.py b/diffusers/scripts/convert_aura_flow_to_diffusers.py new file mode 100755 index 0000000..74c34f4 --- /dev/null +++ b/diffusers/scripts/convert_aura_flow_to_diffusers.py @@ -0,0 +1,131 @@ +import argparse + +import torch +from huggingface_hub import hf_hub_download + +from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel + + +def load_original_state_dict(args): + model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin") + state_dict = torch.load(model_pt, map_location="cpu") + return state_dict + + +def calculate_layers(state_dict_keys, key_prefix): + dit_layers = set() + for k in state_dict_keys: + if key_prefix in k: + dit_layers.add(int(k.split(".")[2])) + print(f"{key_prefix}: {len(dit_layers)}") + return len(dit_layers) + + +# similar to SD3 but only for the last norm layer +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_transformer(state_dict): + converted_state_dict = {} + state_dict_keys = list(state_dict.keys()) + + converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens") + converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding") + converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight") + converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias") + + converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight") + converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias") + converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight") + converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias") + + converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight") + + mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") + single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") + + # MMDiT blocks 🎸. + for i in range(mmdit_layers): + # feed-forward + path_mapping = {"mlpX": "ff", "mlpC": "ff_context"} + weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} + for orig_k, diffuser_k in path_mapping.items(): + for k, v in weight_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop( + f"model.double_layers.{i}.{orig_k}.{k}.weight" + ) + + # norms + path_mapping = {"modX": "norm1", "modC": "norm1_context"} + for orig_k, diffuser_k in path_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop( + f"model.double_layers.{i}.{orig_k}.1.weight" + ) + + # attns + x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"} + context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"} + for attn_mapping in [x_attn_mapping, context_attn_mapping]: + for k, v in attn_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop( + f"model.double_layers.{i}.attn.{k}.weight" + ) + + # Single-DiT blocks. + for i in range(single_dit_layers): + # feed-forward + mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} + for k, v in mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop( + f"model.single_layers.{i}.mlp.{k}.weight" + ) + + # norms + converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop( + f"model.single_layers.{i}.modCX.1.weight" + ) + + # attns + x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"} + for k, v in x_attn_mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop( + f"model.single_layers.{i}.attn.{k}.weight" + ) + + # Final blocks. + converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None) + + return converted_state_dict + + +@torch.no_grad() +def populate_state_dict(args): + original_state_dict = load_original_state_dict(args) + state_dict_keys = list(original_state_dict.keys()) + mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") + single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") + + converted_state_dict = convert_transformer(original_state_dict) + model_diffusers = AuraFlowTransformer2DModel( + num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers + ) + model_diffusers.load_state_dict(converted_state_dict, strict=True) + + return model_diffusers + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str) + parser.add_argument("--dump_path", default="aura-flow", type=str) + parser.add_argument("--hub_id", default=None, type=str) + args = parser.parse_args() + + model_diffusers = populate_state_dict(args) + model_diffusers.save_pretrained(args.dump_path) + if args.hub_id is not None: + model_diffusers.push_to_hub(args.hub_id) diff --git a/diffusers/scripts/convert_blipdiffusion_to_diffusers.py b/diffusers/scripts/convert_blipdiffusion_to_diffusers.py new file mode 100755 index 0000000..03cf67e --- /dev/null +++ b/diffusers/scripts/convert_blipdiffusion_to_diffusers.py @@ -0,0 +1,343 @@ +""" +This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main. +""" + +import argparse +import os +import tempfile + +import torch +from lavis.models import load_model_and_preprocess +from transformers import CLIPTokenizer +from transformers.models.blip_2.configuration_blip_2 import Blip2Config + +from diffusers import ( + AutoencoderKL, + PNDMScheduler, + UNet2DConditionModel, +) +from diffusers.pipelines import BlipDiffusionPipeline +from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor +from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel +from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel + + +BLIP2_CONFIG = { + "vision_config": { + "hidden_size": 1024, + "num_hidden_layers": 23, + "num_attention_heads": 16, + "image_size": 224, + "patch_size": 14, + "intermediate_size": 4096, + "hidden_act": "quick_gelu", + }, + "qformer_config": { + "cross_attention_frequency": 1, + "encoder_hidden_size": 1024, + "vocab_size": 30523, + }, + "num_query_tokens": 16, +} +blip2config = Blip2Config(**BLIP2_CONFIG) + + +def qformer_model_from_original_config(): + qformer = Blip2QFormerModel(blip2config) + return qformer + + +def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix): + embeddings = {} + embeddings.update( + { + f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[ + f"{original_embeddings_prefix}.word_embeddings.weight" + ] + } + ) + embeddings.update( + { + f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[ + f"{original_embeddings_prefix}.position_embeddings.weight" + ] + } + ) + embeddings.update( + {f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]} + ) + embeddings.update( + {f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]} + ) + return embeddings + + +def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix): + proj_layer = {} + proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]}) + proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]}) + proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]}) + proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]}) + proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]}) + proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]}) + return proj_layer + + +def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix): + attention = {} + attention.update( + { + f"{diffuser_attention_prefix}.attention.query.weight": model[ + f"{original_attention_prefix}.self.query.weight" + ] + } + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]} + ) + attention.update( + { + f"{diffuser_attention_prefix}.attention.value.weight": model[ + f"{original_attention_prefix}.self.value.weight" + ] + } + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]} + ) + attention.update( + { + f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[ + f"{original_attention_prefix}.output.LayerNorm.weight" + ] + } + ) + attention.update( + { + f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[ + f"{original_attention_prefix}.output.LayerNorm.bias" + ] + } + ) + return attention + + +def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix): + output_layers = {} + output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]}) + output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]}) + output_layers.update( + {f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]} + ) + output_layers.update( + {f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]} + ) + return output_layers + + +def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix): + encoder = {} + for i in range(blip2config.qformer_config.num_hidden_layers): + encoder.update( + attention_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention" + ) + ) + encoder.update( + attention_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention" + ) + ) + + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[ + f"{original_encoder_prefix}.{i}.intermediate.dense.weight" + ] + } + ) + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[ + f"{original_encoder_prefix}.{i}.intermediate.dense.bias" + ] + } + ) + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[ + f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight" + ] + } + ) + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[ + f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias" + ] + } + ) + + encoder.update( + output_layers_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output" + ) + ) + encoder.update( + output_layers_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query" + ) + ) + return encoder + + +def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix): + visual_encoder_layer = {} + + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]}) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]} + ) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]} + ) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]} + ) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]} + ) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]}) + + return visual_encoder_layer + + +def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix): + visual_encoder = {} + + visual_encoder.update( + { + f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"] + .unsqueeze(0) + .unsqueeze(0) + } + ) + visual_encoder.update( + { + f"{diffuser_prefix}.embeddings.position_embedding": model[ + f"{original_prefix}.positional_embedding" + ].unsqueeze(0) + } + ) + visual_encoder.update( + {f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]} + ) + visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]}) + visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]}) + + for i in range(blip2config.vision_config.num_hidden_layers): + visual_encoder.update( + visual_encoder_layer_from_original_checkpoint( + model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}" + ) + ) + + visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]}) + visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]}) + + return visual_encoder + + +def qformer_original_checkpoint_to_diffusers_checkpoint(model): + qformer_checkpoint = {} + qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings")) + qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]}) + qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer")) + qformer_checkpoint.update( + encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer") + ) + qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder")) + return qformer_checkpoint + + +def get_qformer(model): + print("loading qformer") + + qformer = qformer_model_from_original_config() + qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model) + + load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer) + + print("done loading qformer") + return qformer + + +def load_checkpoint_to_model(checkpoint, model): + with tempfile.NamedTemporaryFile(delete=False) as file: + torch.save(checkpoint, file.name) + del checkpoint + model.load_state_dict(torch.load(file.name), strict=False) + + os.remove(file.name) + + +def save_blip_diffusion_model(model, args): + qformer = get_qformer(model) + qformer.eval() + + text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae") + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + vae.eval() + text_encoder.eval() + scheduler = PNDMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + set_alpha_to_one=False, + skip_prk_steps=True, + ) + tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer") + image_processor = BlipImageProcessor() + blip_diffusion = BlipDiffusionPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + image_processor=image_processor, + ) + blip_diffusion.save_pretrained(args.checkpoint_path) + + +def main(args): + model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True) + save_blip_diffusion_model(model.state_dict(), args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + args = parser.parse_args() + + main(args) diff --git a/diffusers/scripts/convert_cogvideox_to_diffusers.py b/diffusers/scripts/convert_cogvideox_to_diffusers.py new file mode 100755 index 0000000..4343eaf --- /dev/null +++ b/diffusers/scripts/convert_cogvideox_to_diffusers.py @@ -0,0 +1,290 @@ +import argparse +from typing import Any, Dict + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDDIMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXPipeline, + CogVideoXTransformer3DModel, +) + + +def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): + to_q_key = key.replace("query_key_value", "to_q") + to_k_key = key.replace("query_key_value", "to_k") + to_v_key = key.replace("query_key_value", "to_v") + to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0) + state_dict[to_q_key] = to_q + state_dict[to_k_key] = to_k + state_dict[to_v_key] = to_v + state_dict.pop(key) + + +def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]): + layer_id, weight_or_bias = key.split(".")[-2:] + + if "query" in key: + new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}" + elif "key" in key: + new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}" + + state_dict[new_key] = state_dict.pop(key) + + +def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]): + layer_id, _, weight_or_bias = key.split(".")[-3:] + + weights_or_biases = state_dict[key].chunk(12, dim=0) + norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9]) + norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12]) + + norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}" + state_dict[norm1_key] = norm1_weights_or_biases + + norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}" + state_dict[norm2_key] = norm2_weights_or_biases + + state_dict.pop(key) + + +def remove_keys_inplace(key: str, state_dict: Dict[str, Any]): + state_dict.pop(key) + + +def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): + key_split = key.split(".") + layer_index = int(key_split[2]) + replace_layer_index = 4 - 1 - layer_index + + key_split[1] = "up_blocks" + key_split[2] = str(replace_layer_index) + new_key = ".".join(key_split) + + state_dict[new_key] = state_dict.pop(key) + + +TRANSFORMER_KEYS_RENAME_DICT = { + "transformer.final_layernorm": "norm_final", + "transformer": "transformer_blocks", + "attention": "attn1", + "mlp": "ff.net", + "dense_h_to_4h": "0.proj", + "dense_4h_to_h": "2", + ".layers": "", + "dense": "to_out.0", + "input_layernorm": "norm1.norm", + "post_attn1_layernorm": "norm2.norm", + "time_embed.0": "time_embedding.linear_1", + "time_embed.2": "time_embedding.linear_2", + "mixins.patch_embed": "patch_embed", + "mixins.final_layer.norm_final": "norm_out.norm", + "mixins.final_layer.linear": "proj_out", + "mixins.final_layer.adaLN_modulation.1": "norm_out.linear", + "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding", # Specific to CogVideoX-5b-I2V +} + +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "query_key_value": reassign_query_key_value_inplace, + "query_layernorm_list": reassign_query_key_layernorm_inplace, + "key_layernorm_list": reassign_query_key_layernorm_inplace, + "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace, + "embed_tokens": remove_keys_inplace, + "freqs_sin": remove_keys_inplace, + "freqs_cos": remove_keys_inplace, + "position_embedding": remove_keys_inplace, +} + +VAE_KEYS_RENAME_DICT = { + "block.": "resnets.", + "down.": "down_blocks.", + "downsample": "downsamplers.0", + "upsample": "upsamplers.0", + "nin_shortcut": "conv_shortcut", + "encoder.mid.block_1": "encoder.mid_block.resnets.0", + "encoder.mid.block_2": "encoder.mid_block.resnets.1", + "decoder.mid.block_1": "decoder.mid_block.resnets.0", + "decoder.mid.block_2": "decoder.mid_block.resnets.1", +} + +VAE_SPECIAL_KEYS_REMAP = { + "loss": remove_keys_inplace, + "up.": replace_up_keys_inplace, +} + +TOKENIZER_MAX_LENGTH = 226 + + +def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: + state_dict = saved_dict + if "model" in saved_dict.keys(): + state_dict = state_dict["model"] + if "module" in saved_dict.keys(): + state_dict = state_dict["module"] + if "state_dict" in saved_dict.keys(): + state_dict = state_dict["state_dict"] + return state_dict + + +def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: + state_dict[new_key] = state_dict.pop(old_key) + + +def convert_transformer( + ckpt_path: str, + num_layers: int, + num_attention_heads: int, + use_rotary_positional_embeddings: bool, + i2v: bool, + dtype: torch.dtype, +): + PREFIX_KEY = "model.diffusion_model." + + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) + transformer = CogVideoXTransformer3DModel( + in_channels=32 if i2v else 16, + num_layers=num_layers, + num_attention_heads=num_attention_heads, + use_rotary_positional_embeddings=use_rotary_positional_embeddings, + use_learned_positional_embeddings=i2v, + ).to(dtype=dtype) + + for key in list(original_state_dict.keys()): + new_key = key[len(PREFIX_KEY) :] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + transformer.load_state_dict(original_state_dict, strict=True) + return transformer + + +def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) + vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype) + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True) + return vae + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" + ) + parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") + parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16") + parser.add_argument( + "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" + ) + parser.add_argument( + "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" + ) + # For CogVideoX-2B, num_layers is 30. For 5B, it is 42 + parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks") + # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48 + parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads") + # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True + parser.add_argument( + "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not" + ) + # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7 + parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE") + # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0 + parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE") + parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + transformer = None + vae = None + + if args.fp16 and args.bf16: + raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.") + + dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32 + + if args.transformer_ckpt_path is not None: + transformer = convert_transformer( + args.transformer_ckpt_path, + args.num_layers, + args.num_attention_heads, + args.use_rotary_positional_embeddings, + args.i2v, + dtype, + ) + if args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype) + + text_encoder_id = "google/t5-v1_1-xxl" + tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) + text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + + # Apparently, the conversion does not work anymore without this :shrug: + for param in text_encoder.parameters(): + param.data = param.data.contiguous() + + scheduler = CogVideoXDDIMScheduler.from_config( + { + "snr_shift_scale": args.snr_shift_scale, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "clip_sample": False, + "num_train_timesteps": 1000, + "prediction_type": "v_prediction", + "rescale_betas_zero_snr": True, + "set_alpha_to_one": True, + "timestep_spacing": "trailing", + } + ) + if args.i2v: + pipeline_cls = CogVideoXImageToVideoPipeline + else: + pipeline_cls = CogVideoXPipeline + + pipe = pipeline_cls( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + + if args.fp16: + pipe = pipe.to(dtype=torch.float16) + if args.bf16: + pipe = pipe.to(dtype=torch.bfloat16) + + # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird + # for users to specify variant when the default is not fp32 and they want to run with the correct default (which + # is either fp16/bf16 here). + + # This is necessary This is necessary for users with insufficient memory, + # such as those using Colab and notebooks, as it can save some memory used for model loading. + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub) diff --git a/diffusers/scripts/convert_consistency_decoder.py b/diffusers/scripts/convert_consistency_decoder.py new file mode 100755 index 0000000..0cb5fc5 --- /dev/null +++ b/diffusers/scripts/convert_consistency_decoder.py @@ -0,0 +1,1128 @@ +import math +import os +import urllib +import warnings +from argparse import ArgumentParser + +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub.utils import insecure_hashlib +from safetensors.torch import load_file as stl +from tqdm import tqdm + +from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel +from diffusers.models.autoencoders.vae import Encoder +from diffusers.models.embeddings import TimestepEmbedding +from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D + + +args = ArgumentParser() +args.add_argument("--save_pretrained", required=False, default=None, type=str) +args.add_argument("--test_image", required=True, type=str) +args = args.parse_args() + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895 """ + res = arr[timesteps].float() + dims_to_append = len(broadcast_shape) - len(res.shape) + return res[(...,) + (None,) * dims_to_append] + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L45 + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas) + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +class ConsistencyDecoder: + def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")): + self.n_distilled_steps = 64 + download_target = _download( + "https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt", + download_root, + ) + self.ckpt = torch.jit.load(download_target).to(device) + self.device = device + sigma_data = 0.5 + betas = betas_for_alpha_bar(1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).to(device) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod) + sigmas = torch.sqrt(1.0 / alphas_cumprod - 1) + self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2) + self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5 + self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5 + + @staticmethod + def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True): + with torch.no_grad(): + space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor") + rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space + if truncate_start: + rounded_timesteps[rounded_timesteps == total_timesteps] -= space + else: + rounded_timesteps[rounded_timesteps == total_timesteps] -= space + rounded_timesteps[rounded_timesteps == 0] += space + return rounded_timesteps + + @staticmethod + def ldm_transform_latent(z, extra_scale_factor=1): + channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294] + channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034] + + if len(z.shape) != 4: + raise ValueError() + + z = z * 0.18215 + channels = [z[:, i] for i in range(z.shape[1])] + + channels = [extra_scale_factor * (c - channel_means[i]) / channel_stds[i] for i, c in enumerate(channels)] + return torch.stack(channels, dim=1) + + @torch.no_grad() + def __call__( + self, + features: torch.Tensor, + schedule=[1.0, 0.5], + generator=None, + ): + features = self.ldm_transform_latent(features) + ts = self.round_timesteps( + torch.arange(0, 1024), + 1024, + self.n_distilled_steps, + truncate_start=False, + ) + shape = ( + features.size(0), + 3, + 8 * features.size(2), + 8 * features.size(3), + ) + x_start = torch.zeros(shape, device=features.device, dtype=features.dtype) + schedule_timesteps = [int((1024 - 1) * s) for s in schedule] + for i in schedule_timesteps: + t = ts[i].item() + t_ = torch.tensor([t] * features.shape[0]).to(self.device) + # noise = torch.randn_like(x_start) + noise = torch.randn(x_start.shape, dtype=x_start.dtype, generator=generator).to(device=x_start.device) + x_start = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape) * noise + ) + c_in = _extract_into_tensor(self.c_in, t_, x_start.shape) + + import torch.nn.functional as F + + from diffusers import UNet2DModel + + if isinstance(self.ckpt, UNet2DModel): + input = torch.concat([c_in * x_start, F.upsample_nearest(features, scale_factor=8)], dim=1) + model_output = self.ckpt(input, t_).sample + else: + model_output = self.ckpt(c_in * x_start, t_, features=features) + + B, C = x_start.shape[:2] + model_output, _ = torch.split(model_output, C, dim=1) + pred_xstart = ( + _extract_into_tensor(self.c_out, t_, x_start.shape) * model_output + + _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start + ).clamp(-1, 1) + x_start = pred_xstart + return x_start + + +def save_image(image, name): + import numpy as np + from PIL import Image + + image = image[0].cpu().numpy() + image = (image + 1.0) * 127.5 + image = image.clip(0, 255).astype(np.uint8) + image = Image.fromarray(image.transpose(1, 2, 0)) + image.save(name) + + +def load_image(uri, size=None, center_crop=False): + import numpy as np + from PIL import Image + + image = Image.open(uri) + if center_crop: + image = image.crop( + ( + (image.width - min(image.width, image.height)) // 2, + (image.height - min(image.width, image.height)) // 2, + (image.width + min(image.width, image.height)) // 2, + (image.height + min(image.width, image.height)) // 2, + ) + ) + if size is not None: + image = image.resize(size) + image = torch.tensor(np.array(image).transpose(2, 0, 1)).unsqueeze(0).float() + image = image / 127.5 - 1.0 + return image + + +class TimestepEmbedding_(nn.Module): + def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None: + super().__init__() + self.emb = nn.Embedding(n_time, n_emb) + self.f_1 = nn.Linear(n_emb, n_out) + self.f_2 = nn.Linear(n_out, n_out) + + def forward(self, x) -> torch.Tensor: + x = self.emb(x) + x = self.f_1(x) + x = F.silu(x) + return self.f_2(x) + + +class ImageEmbedding(nn.Module): + def __init__(self, in_channels=7, out_channels=320) -> None: + super().__init__() + self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, x) -> torch.Tensor: + return self.f(x) + + +class ImageUnembedding(nn.Module): + def __init__(self, in_channels=320, out_channels=6) -> None: + super().__init__() + self.gn = nn.GroupNorm(32, in_channels) + self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, x) -> torch.Tensor: + return self.f(F.silu(self.gn(x))) + + +class ConvResblock(nn.Module): + def __init__(self, in_features=320, out_features=320) -> None: + super().__init__() + self.f_t = nn.Linear(1280, out_features * 2) + + self.gn_1 = nn.GroupNorm(32, in_features) + self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1) + + self.gn_2 = nn.GroupNorm(32, out_features) + self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1) + + skip_conv = in_features != out_features + self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0) if skip_conv else nn.Identity() + + def forward(self, x, t): + x_skip = x + t = self.f_t(F.silu(t)) + t = t.chunk(2, dim=1) + t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1 + t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3) + + gn_1 = F.silu(self.gn_1(x)) + f_1 = self.f_1(gn_1) + + gn_2 = self.gn_2(f_1) + + return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2)) + + +# Also ConvResblock +class Downsample(nn.Module): + def __init__(self, in_channels=320) -> None: + super().__init__() + self.f_t = nn.Linear(1280, in_channels * 2) + + self.gn_1 = nn.GroupNorm(32, in_channels) + self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + self.gn_2 = nn.GroupNorm(32, in_channels) + + self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, x, t) -> torch.Tensor: + x_skip = x + + t = self.f_t(F.silu(t)) + t_1, t_2 = t.chunk(2, dim=1) + t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1 + t_2 = t_2.unsqueeze(2).unsqueeze(3) + + gn_1 = F.silu(self.gn_1(x)) + avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None) + + f_1 = self.f_1(avg_pool2d) + gn_2 = self.gn_2(f_1) + + f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2))) + + return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None) + + +# Also ConvResblock +class Upsample(nn.Module): + def __init__(self, in_channels=1024) -> None: + super().__init__() + self.f_t = nn.Linear(1280, in_channels * 2) + + self.gn_1 = nn.GroupNorm(32, in_channels) + self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + self.gn_2 = nn.GroupNorm(32, in_channels) + + self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, x, t) -> torch.Tensor: + x_skip = x + + t = self.f_t(F.silu(t)) + t_1, t_2 = t.chunk(2, dim=1) + t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1 + t_2 = t_2.unsqueeze(2).unsqueeze(3) + + gn_1 = F.silu(self.gn_1(x)) + upsample = F.upsample_nearest(gn_1, scale_factor=2) + f_1 = self.f_1(upsample) + gn_2 = self.gn_2(f_1) + + f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2))) + + return f_2 + F.upsample_nearest(x_skip, scale_factor=2) + + +class ConvUNetVAE(nn.Module): + def __init__(self) -> None: + super().__init__() + self.embed_image = ImageEmbedding() + self.embed_time = TimestepEmbedding_() + + down_0 = nn.ModuleList( + [ + ConvResblock(320, 320), + ConvResblock(320, 320), + ConvResblock(320, 320), + Downsample(320), + ] + ) + down_1 = nn.ModuleList( + [ + ConvResblock(320, 640), + ConvResblock(640, 640), + ConvResblock(640, 640), + Downsample(640), + ] + ) + down_2 = nn.ModuleList( + [ + ConvResblock(640, 1024), + ConvResblock(1024, 1024), + ConvResblock(1024, 1024), + Downsample(1024), + ] + ) + down_3 = nn.ModuleList( + [ + ConvResblock(1024, 1024), + ConvResblock(1024, 1024), + ConvResblock(1024, 1024), + ] + ) + self.down = nn.ModuleList( + [ + down_0, + down_1, + down_2, + down_3, + ] + ) + + self.mid = nn.ModuleList( + [ + ConvResblock(1024, 1024), + ConvResblock(1024, 1024), + ] + ) + + up_3 = nn.ModuleList( + [ + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 * 2, 1024), + Upsample(1024), + ] + ) + up_2 = nn.ModuleList( + [ + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 * 2, 1024), + ConvResblock(1024 + 640, 1024), + Upsample(1024), + ] + ) + up_1 = nn.ModuleList( + [ + ConvResblock(1024 + 640, 640), + ConvResblock(640 * 2, 640), + ConvResblock(640 * 2, 640), + ConvResblock(320 + 640, 640), + Upsample(640), + ] + ) + up_0 = nn.ModuleList( + [ + ConvResblock(320 + 640, 320), + ConvResblock(320 * 2, 320), + ConvResblock(320 * 2, 320), + ConvResblock(320 * 2, 320), + ] + ) + self.up = nn.ModuleList( + [ + up_0, + up_1, + up_2, + up_3, + ] + ) + + self.output = ImageUnembedding() + + def forward(self, x, t, features) -> torch.Tensor: + converted = hasattr(self, "converted") and self.converted + + x = torch.cat([x, F.upsample_nearest(features, scale_factor=8)], dim=1) + + if converted: + t = self.time_embedding(self.time_proj(t)) + else: + t = self.embed_time(t) + + x = self.embed_image(x) + + skips = [x] + for i, down in enumerate(self.down): + if converted and i in [0, 1, 2, 3]: + x, skips_ = down(x, t) + for skip in skips_: + skips.append(skip) + else: + for block in down: + x = block(x, t) + skips.append(x) + print(x.float().abs().sum()) + + if converted: + x = self.mid(x, t) + else: + for i in range(2): + x = self.mid[i](x, t) + print(x.float().abs().sum()) + + for i, up in enumerate(self.up[::-1]): + if converted and i in [0, 1, 2, 3]: + skip_4 = skips.pop() + skip_3 = skips.pop() + skip_2 = skips.pop() + skip_1 = skips.pop() + skips_ = (skip_1, skip_2, skip_3, skip_4) + x = up(x, skips_, t) + else: + for block in up: + if isinstance(block, ConvResblock): + x = torch.concat([x, skips.pop()], dim=1) + x = block(x, t) + + return self.output(x) + + +def rename_state_dict_key(k): + k = k.replace("blocks.", "") + for i in range(5): + k = k.replace(f"down_{i}_", f"down.{i}.") + k = k.replace(f"conv_{i}.", f"{i}.") + k = k.replace(f"up_{i}_", f"up.{i}.") + k = k.replace(f"mid_{i}", f"mid.{i}") + k = k.replace("upsamp.", "4.") + k = k.replace("downsamp.", "3.") + k = k.replace("f_t.w", "f_t.weight").replace("f_t.b", "f_t.bias") + k = k.replace("f_1.w", "f_1.weight").replace("f_1.b", "f_1.bias") + k = k.replace("f_2.w", "f_2.weight").replace("f_2.b", "f_2.bias") + k = k.replace("f_s.w", "f_s.weight").replace("f_s.b", "f_s.bias") + k = k.replace("f.w", "f.weight").replace("f.b", "f.bias") + k = k.replace("gn_1.g", "gn_1.weight").replace("gn_1.b", "gn_1.bias") + k = k.replace("gn_2.g", "gn_2.weight").replace("gn_2.b", "gn_2.bias") + k = k.replace("gn.g", "gn.weight").replace("gn.b", "gn.bias") + return k + + +def rename_state_dict(sd, embedding): + sd = {rename_state_dict_key(k): v for k, v in sd.items()} + sd["embed_time.emb.weight"] = embedding["weight"] + return sd + + +# encode with stable diffusion vae +pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +pipe.vae.cuda() + +# construct original decoder with jitted model +decoder_consistency = ConsistencyDecoder(device="cuda:0") + +# construct UNet code, overwrite the decoder with conv_unet_vae +model = ConvUNetVAE() +model.load_state_dict( + rename_state_dict( + stl("consistency_decoder.safetensors"), + stl("embedding.safetensors"), + ) +) +model = model.cuda() + +decoder_consistency.ckpt = model + +image = load_image(args.test_image, size=(256, 256), center_crop=True) +latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample() + +# decode with gan +sample_gan = pipe.vae.decode(latent).sample.detach() +save_image(sample_gan, "gan.png") + +# decode with conv_unet_vae +sample_consistency_orig = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0)) +save_image(sample_consistency_orig, "con_orig.png") + + +########### conversion + +print("CONVERSION") + +print("DOWN BLOCK ONE") + +block_one_sd_orig = model.down[0].state_dict() +block_one_sd_new = {} + +for i in range(3): + block_one_sd_new[f"resnets.{i}.norm1.weight"] = block_one_sd_orig.pop(f"{i}.gn_1.weight") + block_one_sd_new[f"resnets.{i}.norm1.bias"] = block_one_sd_orig.pop(f"{i}.gn_1.bias") + block_one_sd_new[f"resnets.{i}.conv1.weight"] = block_one_sd_orig.pop(f"{i}.f_1.weight") + block_one_sd_new[f"resnets.{i}.conv1.bias"] = block_one_sd_orig.pop(f"{i}.f_1.bias") + block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_one_sd_orig.pop(f"{i}.f_t.weight") + block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_one_sd_orig.pop(f"{i}.f_t.bias") + block_one_sd_new[f"resnets.{i}.norm2.weight"] = block_one_sd_orig.pop(f"{i}.gn_2.weight") + block_one_sd_new[f"resnets.{i}.norm2.bias"] = block_one_sd_orig.pop(f"{i}.gn_2.bias") + block_one_sd_new[f"resnets.{i}.conv2.weight"] = block_one_sd_orig.pop(f"{i}.f_2.weight") + block_one_sd_new[f"resnets.{i}.conv2.bias"] = block_one_sd_orig.pop(f"{i}.f_2.bias") + +block_one_sd_new["downsamplers.0.norm1.weight"] = block_one_sd_orig.pop("3.gn_1.weight") +block_one_sd_new["downsamplers.0.norm1.bias"] = block_one_sd_orig.pop("3.gn_1.bias") +block_one_sd_new["downsamplers.0.conv1.weight"] = block_one_sd_orig.pop("3.f_1.weight") +block_one_sd_new["downsamplers.0.conv1.bias"] = block_one_sd_orig.pop("3.f_1.bias") +block_one_sd_new["downsamplers.0.time_emb_proj.weight"] = block_one_sd_orig.pop("3.f_t.weight") +block_one_sd_new["downsamplers.0.time_emb_proj.bias"] = block_one_sd_orig.pop("3.f_t.bias") +block_one_sd_new["downsamplers.0.norm2.weight"] = block_one_sd_orig.pop("3.gn_2.weight") +block_one_sd_new["downsamplers.0.norm2.bias"] = block_one_sd_orig.pop("3.gn_2.bias") +block_one_sd_new["downsamplers.0.conv2.weight"] = block_one_sd_orig.pop("3.f_2.weight") +block_one_sd_new["downsamplers.0.conv2.bias"] = block_one_sd_orig.pop("3.f_2.bias") + +assert len(block_one_sd_orig) == 0 + +block_one = ResnetDownsampleBlock2D( + in_channels=320, + out_channels=320, + temb_channels=1280, + num_layers=3, + add_downsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +block_one.load_state_dict(block_one_sd_new) + +print("DOWN BLOCK TWO") + +block_two_sd_orig = model.down[1].state_dict() +block_two_sd_new = {} + +for i in range(3): + block_two_sd_new[f"resnets.{i}.norm1.weight"] = block_two_sd_orig.pop(f"{i}.gn_1.weight") + block_two_sd_new[f"resnets.{i}.norm1.bias"] = block_two_sd_orig.pop(f"{i}.gn_1.bias") + block_two_sd_new[f"resnets.{i}.conv1.weight"] = block_two_sd_orig.pop(f"{i}.f_1.weight") + block_two_sd_new[f"resnets.{i}.conv1.bias"] = block_two_sd_orig.pop(f"{i}.f_1.bias") + block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_two_sd_orig.pop(f"{i}.f_t.weight") + block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_two_sd_orig.pop(f"{i}.f_t.bias") + block_two_sd_new[f"resnets.{i}.norm2.weight"] = block_two_sd_orig.pop(f"{i}.gn_2.weight") + block_two_sd_new[f"resnets.{i}.norm2.bias"] = block_two_sd_orig.pop(f"{i}.gn_2.bias") + block_two_sd_new[f"resnets.{i}.conv2.weight"] = block_two_sd_orig.pop(f"{i}.f_2.weight") + block_two_sd_new[f"resnets.{i}.conv2.bias"] = block_two_sd_orig.pop(f"{i}.f_2.bias") + + if i == 0: + block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_two_sd_orig.pop(f"{i}.f_s.weight") + block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_two_sd_orig.pop(f"{i}.f_s.bias") + +block_two_sd_new["downsamplers.0.norm1.weight"] = block_two_sd_orig.pop("3.gn_1.weight") +block_two_sd_new["downsamplers.0.norm1.bias"] = block_two_sd_orig.pop("3.gn_1.bias") +block_two_sd_new["downsamplers.0.conv1.weight"] = block_two_sd_orig.pop("3.f_1.weight") +block_two_sd_new["downsamplers.0.conv1.bias"] = block_two_sd_orig.pop("3.f_1.bias") +block_two_sd_new["downsamplers.0.time_emb_proj.weight"] = block_two_sd_orig.pop("3.f_t.weight") +block_two_sd_new["downsamplers.0.time_emb_proj.bias"] = block_two_sd_orig.pop("3.f_t.bias") +block_two_sd_new["downsamplers.0.norm2.weight"] = block_two_sd_orig.pop("3.gn_2.weight") +block_two_sd_new["downsamplers.0.norm2.bias"] = block_two_sd_orig.pop("3.gn_2.bias") +block_two_sd_new["downsamplers.0.conv2.weight"] = block_two_sd_orig.pop("3.f_2.weight") +block_two_sd_new["downsamplers.0.conv2.bias"] = block_two_sd_orig.pop("3.f_2.bias") + +assert len(block_two_sd_orig) == 0 + +block_two = ResnetDownsampleBlock2D( + in_channels=320, + out_channels=640, + temb_channels=1280, + num_layers=3, + add_downsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +block_two.load_state_dict(block_two_sd_new) + +print("DOWN BLOCK THREE") + +block_three_sd_orig = model.down[2].state_dict() +block_three_sd_new = {} + +for i in range(3): + block_three_sd_new[f"resnets.{i}.norm1.weight"] = block_three_sd_orig.pop(f"{i}.gn_1.weight") + block_three_sd_new[f"resnets.{i}.norm1.bias"] = block_three_sd_orig.pop(f"{i}.gn_1.bias") + block_three_sd_new[f"resnets.{i}.conv1.weight"] = block_three_sd_orig.pop(f"{i}.f_1.weight") + block_three_sd_new[f"resnets.{i}.conv1.bias"] = block_three_sd_orig.pop(f"{i}.f_1.bias") + block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_three_sd_orig.pop(f"{i}.f_t.weight") + block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_three_sd_orig.pop(f"{i}.f_t.bias") + block_three_sd_new[f"resnets.{i}.norm2.weight"] = block_three_sd_orig.pop(f"{i}.gn_2.weight") + block_three_sd_new[f"resnets.{i}.norm2.bias"] = block_three_sd_orig.pop(f"{i}.gn_2.bias") + block_three_sd_new[f"resnets.{i}.conv2.weight"] = block_three_sd_orig.pop(f"{i}.f_2.weight") + block_three_sd_new[f"resnets.{i}.conv2.bias"] = block_three_sd_orig.pop(f"{i}.f_2.bias") + + if i == 0: + block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_three_sd_orig.pop(f"{i}.f_s.weight") + block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_three_sd_orig.pop(f"{i}.f_s.bias") + +block_three_sd_new["downsamplers.0.norm1.weight"] = block_three_sd_orig.pop("3.gn_1.weight") +block_three_sd_new["downsamplers.0.norm1.bias"] = block_three_sd_orig.pop("3.gn_1.bias") +block_three_sd_new["downsamplers.0.conv1.weight"] = block_three_sd_orig.pop("3.f_1.weight") +block_three_sd_new["downsamplers.0.conv1.bias"] = block_three_sd_orig.pop("3.f_1.bias") +block_three_sd_new["downsamplers.0.time_emb_proj.weight"] = block_three_sd_orig.pop("3.f_t.weight") +block_three_sd_new["downsamplers.0.time_emb_proj.bias"] = block_three_sd_orig.pop("3.f_t.bias") +block_three_sd_new["downsamplers.0.norm2.weight"] = block_three_sd_orig.pop("3.gn_2.weight") +block_three_sd_new["downsamplers.0.norm2.bias"] = block_three_sd_orig.pop("3.gn_2.bias") +block_three_sd_new["downsamplers.0.conv2.weight"] = block_three_sd_orig.pop("3.f_2.weight") +block_three_sd_new["downsamplers.0.conv2.bias"] = block_three_sd_orig.pop("3.f_2.bias") + +assert len(block_three_sd_orig) == 0 + +block_three = ResnetDownsampleBlock2D( + in_channels=640, + out_channels=1024, + temb_channels=1280, + num_layers=3, + add_downsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +block_three.load_state_dict(block_three_sd_new) + +print("DOWN BLOCK FOUR") + +block_four_sd_orig = model.down[3].state_dict() +block_four_sd_new = {} + +for i in range(3): + block_four_sd_new[f"resnets.{i}.norm1.weight"] = block_four_sd_orig.pop(f"{i}.gn_1.weight") + block_four_sd_new[f"resnets.{i}.norm1.bias"] = block_four_sd_orig.pop(f"{i}.gn_1.bias") + block_four_sd_new[f"resnets.{i}.conv1.weight"] = block_four_sd_orig.pop(f"{i}.f_1.weight") + block_four_sd_new[f"resnets.{i}.conv1.bias"] = block_four_sd_orig.pop(f"{i}.f_1.bias") + block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_four_sd_orig.pop(f"{i}.f_t.weight") + block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_four_sd_orig.pop(f"{i}.f_t.bias") + block_four_sd_new[f"resnets.{i}.norm2.weight"] = block_four_sd_orig.pop(f"{i}.gn_2.weight") + block_four_sd_new[f"resnets.{i}.norm2.bias"] = block_four_sd_orig.pop(f"{i}.gn_2.bias") + block_four_sd_new[f"resnets.{i}.conv2.weight"] = block_four_sd_orig.pop(f"{i}.f_2.weight") + block_four_sd_new[f"resnets.{i}.conv2.bias"] = block_four_sd_orig.pop(f"{i}.f_2.bias") + +assert len(block_four_sd_orig) == 0 + +block_four = ResnetDownsampleBlock2D( + in_channels=1024, + out_channels=1024, + temb_channels=1280, + num_layers=3, + add_downsample=False, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +block_four.load_state_dict(block_four_sd_new) + + +print("MID BLOCK 1") + +mid_block_one_sd_orig = model.mid.state_dict() +mid_block_one_sd_new = {} + +for i in range(2): + mid_block_one_sd_new[f"resnets.{i}.norm1.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.weight") + mid_block_one_sd_new[f"resnets.{i}.norm1.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.bias") + mid_block_one_sd_new[f"resnets.{i}.conv1.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_1.weight") + mid_block_one_sd_new[f"resnets.{i}.conv1.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_1.bias") + mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_t.weight") + mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_t.bias") + mid_block_one_sd_new[f"resnets.{i}.norm2.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.weight") + mid_block_one_sd_new[f"resnets.{i}.norm2.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.bias") + mid_block_one_sd_new[f"resnets.{i}.conv2.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_2.weight") + mid_block_one_sd_new[f"resnets.{i}.conv2.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_2.bias") + +assert len(mid_block_one_sd_orig) == 0 + +mid_block_one = UNetMidBlock2D( + in_channels=1024, + temb_channels=1280, + num_layers=1, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, + add_attention=False, +) + +mid_block_one.load_state_dict(mid_block_one_sd_new) + +print("UP BLOCK ONE") + +up_block_one_sd_orig = model.up[-1].state_dict() +up_block_one_sd_new = {} + +for i in range(4): + up_block_one_sd_new[f"resnets.{i}.norm1.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_1.weight") + up_block_one_sd_new[f"resnets.{i}.norm1.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_1.bias") + up_block_one_sd_new[f"resnets.{i}.conv1.weight"] = up_block_one_sd_orig.pop(f"{i}.f_1.weight") + up_block_one_sd_new[f"resnets.{i}.conv1.bias"] = up_block_one_sd_orig.pop(f"{i}.f_1.bias") + up_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_one_sd_orig.pop(f"{i}.f_t.weight") + up_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_one_sd_orig.pop(f"{i}.f_t.bias") + up_block_one_sd_new[f"resnets.{i}.norm2.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_2.weight") + up_block_one_sd_new[f"resnets.{i}.norm2.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_2.bias") + up_block_one_sd_new[f"resnets.{i}.conv2.weight"] = up_block_one_sd_orig.pop(f"{i}.f_2.weight") + up_block_one_sd_new[f"resnets.{i}.conv2.bias"] = up_block_one_sd_orig.pop(f"{i}.f_2.bias") + up_block_one_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_one_sd_orig.pop(f"{i}.f_s.weight") + up_block_one_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_one_sd_orig.pop(f"{i}.f_s.bias") + +up_block_one_sd_new["upsamplers.0.norm1.weight"] = up_block_one_sd_orig.pop("4.gn_1.weight") +up_block_one_sd_new["upsamplers.0.norm1.bias"] = up_block_one_sd_orig.pop("4.gn_1.bias") +up_block_one_sd_new["upsamplers.0.conv1.weight"] = up_block_one_sd_orig.pop("4.f_1.weight") +up_block_one_sd_new["upsamplers.0.conv1.bias"] = up_block_one_sd_orig.pop("4.f_1.bias") +up_block_one_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_one_sd_orig.pop("4.f_t.weight") +up_block_one_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_one_sd_orig.pop("4.f_t.bias") +up_block_one_sd_new["upsamplers.0.norm2.weight"] = up_block_one_sd_orig.pop("4.gn_2.weight") +up_block_one_sd_new["upsamplers.0.norm2.bias"] = up_block_one_sd_orig.pop("4.gn_2.bias") +up_block_one_sd_new["upsamplers.0.conv2.weight"] = up_block_one_sd_orig.pop("4.f_2.weight") +up_block_one_sd_new["upsamplers.0.conv2.bias"] = up_block_one_sd_orig.pop("4.f_2.bias") + +assert len(up_block_one_sd_orig) == 0 + +up_block_one = ResnetUpsampleBlock2D( + in_channels=1024, + prev_output_channel=1024, + out_channels=1024, + temb_channels=1280, + num_layers=4, + add_upsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +up_block_one.load_state_dict(up_block_one_sd_new) + +print("UP BLOCK TWO") + +up_block_two_sd_orig = model.up[-2].state_dict() +up_block_two_sd_new = {} + +for i in range(4): + up_block_two_sd_new[f"resnets.{i}.norm1.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_1.weight") + up_block_two_sd_new[f"resnets.{i}.norm1.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_1.bias") + up_block_two_sd_new[f"resnets.{i}.conv1.weight"] = up_block_two_sd_orig.pop(f"{i}.f_1.weight") + up_block_two_sd_new[f"resnets.{i}.conv1.bias"] = up_block_two_sd_orig.pop(f"{i}.f_1.bias") + up_block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_two_sd_orig.pop(f"{i}.f_t.weight") + up_block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_two_sd_orig.pop(f"{i}.f_t.bias") + up_block_two_sd_new[f"resnets.{i}.norm2.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_2.weight") + up_block_two_sd_new[f"resnets.{i}.norm2.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_2.bias") + up_block_two_sd_new[f"resnets.{i}.conv2.weight"] = up_block_two_sd_orig.pop(f"{i}.f_2.weight") + up_block_two_sd_new[f"resnets.{i}.conv2.bias"] = up_block_two_sd_orig.pop(f"{i}.f_2.bias") + up_block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_two_sd_orig.pop(f"{i}.f_s.weight") + up_block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_two_sd_orig.pop(f"{i}.f_s.bias") + +up_block_two_sd_new["upsamplers.0.norm1.weight"] = up_block_two_sd_orig.pop("4.gn_1.weight") +up_block_two_sd_new["upsamplers.0.norm1.bias"] = up_block_two_sd_orig.pop("4.gn_1.bias") +up_block_two_sd_new["upsamplers.0.conv1.weight"] = up_block_two_sd_orig.pop("4.f_1.weight") +up_block_two_sd_new["upsamplers.0.conv1.bias"] = up_block_two_sd_orig.pop("4.f_1.bias") +up_block_two_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_two_sd_orig.pop("4.f_t.weight") +up_block_two_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_two_sd_orig.pop("4.f_t.bias") +up_block_two_sd_new["upsamplers.0.norm2.weight"] = up_block_two_sd_orig.pop("4.gn_2.weight") +up_block_two_sd_new["upsamplers.0.norm2.bias"] = up_block_two_sd_orig.pop("4.gn_2.bias") +up_block_two_sd_new["upsamplers.0.conv2.weight"] = up_block_two_sd_orig.pop("4.f_2.weight") +up_block_two_sd_new["upsamplers.0.conv2.bias"] = up_block_two_sd_orig.pop("4.f_2.bias") + +assert len(up_block_two_sd_orig) == 0 + +up_block_two = ResnetUpsampleBlock2D( + in_channels=640, + prev_output_channel=1024, + out_channels=1024, + temb_channels=1280, + num_layers=4, + add_upsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +up_block_two.load_state_dict(up_block_two_sd_new) + +print("UP BLOCK THREE") + +up_block_three_sd_orig = model.up[-3].state_dict() +up_block_three_sd_new = {} + +for i in range(4): + up_block_three_sd_new[f"resnets.{i}.norm1.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_1.weight") + up_block_three_sd_new[f"resnets.{i}.norm1.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_1.bias") + up_block_three_sd_new[f"resnets.{i}.conv1.weight"] = up_block_three_sd_orig.pop(f"{i}.f_1.weight") + up_block_three_sd_new[f"resnets.{i}.conv1.bias"] = up_block_three_sd_orig.pop(f"{i}.f_1.bias") + up_block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_three_sd_orig.pop(f"{i}.f_t.weight") + up_block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_three_sd_orig.pop(f"{i}.f_t.bias") + up_block_three_sd_new[f"resnets.{i}.norm2.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_2.weight") + up_block_three_sd_new[f"resnets.{i}.norm2.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_2.bias") + up_block_three_sd_new[f"resnets.{i}.conv2.weight"] = up_block_three_sd_orig.pop(f"{i}.f_2.weight") + up_block_three_sd_new[f"resnets.{i}.conv2.bias"] = up_block_three_sd_orig.pop(f"{i}.f_2.bias") + up_block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_three_sd_orig.pop(f"{i}.f_s.weight") + up_block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_three_sd_orig.pop(f"{i}.f_s.bias") + +up_block_three_sd_new["upsamplers.0.norm1.weight"] = up_block_three_sd_orig.pop("4.gn_1.weight") +up_block_three_sd_new["upsamplers.0.norm1.bias"] = up_block_three_sd_orig.pop("4.gn_1.bias") +up_block_three_sd_new["upsamplers.0.conv1.weight"] = up_block_three_sd_orig.pop("4.f_1.weight") +up_block_three_sd_new["upsamplers.0.conv1.bias"] = up_block_three_sd_orig.pop("4.f_1.bias") +up_block_three_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_three_sd_orig.pop("4.f_t.weight") +up_block_three_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_three_sd_orig.pop("4.f_t.bias") +up_block_three_sd_new["upsamplers.0.norm2.weight"] = up_block_three_sd_orig.pop("4.gn_2.weight") +up_block_three_sd_new["upsamplers.0.norm2.bias"] = up_block_three_sd_orig.pop("4.gn_2.bias") +up_block_three_sd_new["upsamplers.0.conv2.weight"] = up_block_three_sd_orig.pop("4.f_2.weight") +up_block_three_sd_new["upsamplers.0.conv2.bias"] = up_block_three_sd_orig.pop("4.f_2.bias") + +assert len(up_block_three_sd_orig) == 0 + +up_block_three = ResnetUpsampleBlock2D( + in_channels=320, + prev_output_channel=1024, + out_channels=640, + temb_channels=1280, + num_layers=4, + add_upsample=True, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +up_block_three.load_state_dict(up_block_three_sd_new) + +print("UP BLOCK FOUR") + +up_block_four_sd_orig = model.up[-4].state_dict() +up_block_four_sd_new = {} + +for i in range(4): + up_block_four_sd_new[f"resnets.{i}.norm1.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_1.weight") + up_block_four_sd_new[f"resnets.{i}.norm1.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_1.bias") + up_block_four_sd_new[f"resnets.{i}.conv1.weight"] = up_block_four_sd_orig.pop(f"{i}.f_1.weight") + up_block_four_sd_new[f"resnets.{i}.conv1.bias"] = up_block_four_sd_orig.pop(f"{i}.f_1.bias") + up_block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_four_sd_orig.pop(f"{i}.f_t.weight") + up_block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_four_sd_orig.pop(f"{i}.f_t.bias") + up_block_four_sd_new[f"resnets.{i}.norm2.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_2.weight") + up_block_four_sd_new[f"resnets.{i}.norm2.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_2.bias") + up_block_four_sd_new[f"resnets.{i}.conv2.weight"] = up_block_four_sd_orig.pop(f"{i}.f_2.weight") + up_block_four_sd_new[f"resnets.{i}.conv2.bias"] = up_block_four_sd_orig.pop(f"{i}.f_2.bias") + up_block_four_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_four_sd_orig.pop(f"{i}.f_s.weight") + up_block_four_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_four_sd_orig.pop(f"{i}.f_s.bias") + +assert len(up_block_four_sd_orig) == 0 + +up_block_four = ResnetUpsampleBlock2D( + in_channels=320, + prev_output_channel=640, + out_channels=320, + temb_channels=1280, + num_layers=4, + add_upsample=False, + resnet_time_scale_shift="scale_shift", + resnet_eps=1e-5, +) + +up_block_four.load_state_dict(up_block_four_sd_new) + +print("initial projection (conv_in)") + +conv_in_sd_orig = model.embed_image.state_dict() +conv_in_sd_new = {} + +conv_in_sd_new["weight"] = conv_in_sd_orig.pop("f.weight") +conv_in_sd_new["bias"] = conv_in_sd_orig.pop("f.bias") + +assert len(conv_in_sd_orig) == 0 + +block_out_channels = [320, 640, 1024, 1024] + +in_channels = 7 +conv_in_kernel = 3 +conv_in_padding = (conv_in_kernel - 1) // 2 +conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding) + +conv_in.load_state_dict(conv_in_sd_new) + +print("out projection (conv_out) (conv_norm_out)") +out_channels = 6 +norm_num_groups = 32 +norm_eps = 1e-5 +act_fn = "silu" +conv_out_kernel = 3 +conv_out_padding = (conv_out_kernel - 1) // 2 +conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) +# uses torch.functional in orig +# conv_act = get_activation(act_fn) +conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding) + +conv_norm_out.load_state_dict(model.output.gn.state_dict()) +conv_out.load_state_dict(model.output.f.state_dict()) + +print("timestep projection (time_proj) (time_embedding)") + +f1_sd = model.embed_time.f_1.state_dict() +f2_sd = model.embed_time.f_2.state_dict() + +time_embedding_sd = { + "linear_1.weight": f1_sd.pop("weight"), + "linear_1.bias": f1_sd.pop("bias"), + "linear_2.weight": f2_sd.pop("weight"), + "linear_2.bias": f2_sd.pop("bias"), +} + +assert len(f1_sd) == 0 +assert len(f2_sd) == 0 + +time_embedding_type = "learned" +num_train_timesteps = 1024 +time_embedding_dim = 1280 + +time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0]) +timestep_input_dim = block_out_channels[0] + +time_embedding = TimestepEmbedding(timestep_input_dim, time_embedding_dim) + +time_proj.load_state_dict(model.embed_time.emb.state_dict()) +time_embedding.load_state_dict(time_embedding_sd) + +print("CONVERT") + +time_embedding.to("cuda") +time_proj.to("cuda") +conv_in.to("cuda") + +block_one.to("cuda") +block_two.to("cuda") +block_three.to("cuda") +block_four.to("cuda") + +mid_block_one.to("cuda") + +up_block_one.to("cuda") +up_block_two.to("cuda") +up_block_three.to("cuda") +up_block_four.to("cuda") + +conv_norm_out.to("cuda") +conv_out.to("cuda") + +model.time_proj = time_proj +model.time_embedding = time_embedding +model.embed_image = conv_in + +model.down[0] = block_one +model.down[1] = block_two +model.down[2] = block_three +model.down[3] = block_four + +model.mid = mid_block_one + +model.up[-1] = up_block_one +model.up[-2] = up_block_two +model.up[-3] = up_block_three +model.up[-4] = up_block_four + +model.output.gn = conv_norm_out +model.output.f = conv_out + +model.converted = True + +sample_consistency_new = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0)) +save_image(sample_consistency_new, "con_new.png") + +assert (sample_consistency_orig == sample_consistency_new).all() + +print("making unet") + +unet = UNet2DModel( + in_channels=in_channels, + out_channels=out_channels, + down_block_types=( + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ), + up_block_types=( + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ), + block_out_channels=block_out_channels, + layers_per_block=3, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + resnet_time_scale_shift="scale_shift", + time_embedding_type="learned", + num_train_timesteps=num_train_timesteps, + add_attention=False, +) + +unet_state_dict = {} + + +def add_state_dict(prefix, mod): + for k, v in mod.state_dict().items(): + unet_state_dict[f"{prefix}.{k}"] = v + + +add_state_dict("conv_in", conv_in) +add_state_dict("time_proj", time_proj) +add_state_dict("time_embedding", time_embedding) +add_state_dict("down_blocks.0", block_one) +add_state_dict("down_blocks.1", block_two) +add_state_dict("down_blocks.2", block_three) +add_state_dict("down_blocks.3", block_four) +add_state_dict("mid_block", mid_block_one) +add_state_dict("up_blocks.0", up_block_one) +add_state_dict("up_blocks.1", up_block_two) +add_state_dict("up_blocks.2", up_block_three) +add_state_dict("up_blocks.3", up_block_four) +add_state_dict("conv_norm_out", conv_norm_out) +add_state_dict("conv_out", conv_out) + +unet.load_state_dict(unet_state_dict) + +print("running with diffusers unet") + +unet.to("cuda") + +decoder_consistency.ckpt = unet + +sample_consistency_new_2 = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0)) +save_image(sample_consistency_new_2, "con_new_2.png") + +assert (sample_consistency_orig == sample_consistency_new_2).all() + +print("running with diffusers model") + +Encoder.old_constructor = Encoder.__init__ + + +def new_constructor(self, **kwargs): + self.old_constructor(**kwargs) + self.constructor_arguments = kwargs + + +Encoder.__init__ = new_constructor + + +vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae") +consistency_vae = ConsistencyDecoderVAE( + encoder_args=vae.encoder.constructor_arguments, + decoder_args=unet.config, + scaling_factor=vae.config.scaling_factor, + block_out_channels=vae.config.block_out_channels, + latent_channels=vae.config.latent_channels, +) +consistency_vae.encoder.load_state_dict(vae.encoder.state_dict()) +consistency_vae.quant_conv.load_state_dict(vae.quant_conv.state_dict()) +consistency_vae.decoder_unet.load_state_dict(unet.state_dict()) + +consistency_vae.to(dtype=torch.float16, device="cuda") + +sample_consistency_new_3 = consistency_vae.decode( + 0.18215 * latent, generator=torch.Generator("cpu").manual_seed(0) +).sample + +print("max difference") +print((sample_consistency_orig - sample_consistency_new_3).abs().max()) +print("total difference") +print((sample_consistency_orig - sample_consistency_new_3).abs().sum()) +# assert (sample_consistency_orig == sample_consistency_new_3).all() + +print("running with diffusers pipeline") + +pipe = DiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16 +) +pipe.to("cuda") + +pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("horse.png") + + +if args.save_pretrained is not None: + consistency_vae.save_pretrained(args.save_pretrained) diff --git a/diffusers/scripts/convert_consistency_to_diffusers.py b/diffusers/scripts/convert_consistency_to_diffusers.py new file mode 100755 index 0000000..0f8b4dd --- /dev/null +++ b/diffusers/scripts/convert_consistency_to_diffusers.py @@ -0,0 +1,315 @@ +import argparse +import os + +import torch + +from diffusers import ( + CMStochasticIterativeScheduler, + ConsistencyModelPipeline, + UNet2DModel, +) + + +TEST_UNET_CONFIG = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "layers_per_block": 2, + "num_class_embeds": 1000, + "block_out_channels": [32, 64], + "attention_head_dim": 8, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + ], + "up_block_types": [ + "AttnUpBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "scale_shift", + "attn_norm_num_groups": 32, + "upsample_type": "resnet", + "downsample_type": "resnet", +} + +IMAGENET_64_UNET_CONFIG = { + "sample_size": 64, + "in_channels": 3, + "out_channels": 3, + "layers_per_block": 3, + "num_class_embeds": 1000, + "block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4], + "attention_head_dim": 64, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + ], + "up_block_types": [ + "AttnUpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "scale_shift", + "attn_norm_num_groups": 32, + "upsample_type": "resnet", + "downsample_type": "resnet", +} + +LSUN_256_UNET_CONFIG = { + "sample_size": 256, + "in_channels": 3, + "out_channels": 3, + "layers_per_block": 2, + "num_class_embeds": None, + "block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4], + "attention_head_dim": 64, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + ], + "up_block_types": [ + "AttnUpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "default", + "upsample_type": "resnet", + "downsample_type": "resnet", +} + +CD_SCHEDULER_CONFIG = { + "num_train_timesteps": 40, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + +CT_IMAGENET_64_SCHEDULER_CONFIG = { + "num_train_timesteps": 201, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + +CT_LSUN_256_SCHEDULER_CONFIG = { + "num_train_timesteps": 151, + "sigma_min": 0.002, + "sigma_max": 80.0, +} + + +def str2bool(v): + """ + https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse + """ + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("boolean value expected") + + +def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False): + new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"] + new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"] + new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"] + new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"] + new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"] + new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"] + new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"] + new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"] + new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"] + new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"] + + if has_skip: + new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"] + new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"] + + return new_checkpoint + + +def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None): + weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0) + bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0) + + new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"] + new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"] + + new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1) + + new_checkpoint[f"{new_prefix}.to_out.0.weight"] = ( + checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) + ) + new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1) + + return new_checkpoint + + +def con_pt_to_diffuser(checkpoint_path: str, unet_config): + checkpoint = torch.load(checkpoint_path, map_location="cpu") + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"] + + if unet_config["num_class_embeds"] is not None: + new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"] + + new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"] + + down_block_types = unet_config["down_block_types"] + layers_per_block = unet_config["layers_per_block"] + attention_head_dim = unet_config["attention_head_dim"] + channels_list = unet_config["block_out_channels"] + current_layer = 1 + prev_channels = channels_list[0] + + for i, layer_type in enumerate(down_block_types): + current_channels = channels_list[i] + downsample_block_has_skip = current_channels != prev_channels + if layer_type == "ResnetDownsampleBlock2D": + for j in range(layers_per_block): + new_prefix = f"down_blocks.{i}.resnets.{j}" + old_prefix = f"input_blocks.{current_layer}.0" + has_skip = True if j == 0 and downsample_block_has_skip else False + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip) + current_layer += 1 + + elif layer_type == "AttnDownBlock2D": + for j in range(layers_per_block): + new_prefix = f"down_blocks.{i}.resnets.{j}" + old_prefix = f"input_blocks.{current_layer}.0" + has_skip = True if j == 0 and downsample_block_has_skip else False + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip) + new_prefix = f"down_blocks.{i}.attentions.{j}" + old_prefix = f"input_blocks.{current_layer}.1" + new_checkpoint = convert_attention( + checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim + ) + current_layer += 1 + + if i != len(down_block_types) - 1: + new_prefix = f"down_blocks.{i}.downsamplers.0" + old_prefix = f"input_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + current_layer += 1 + + prev_channels = current_channels + + # hardcoded the mid-block for now + new_prefix = "mid_block.resnets.0" + old_prefix = "middle_block.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_prefix = "mid_block.attentions.0" + old_prefix = "middle_block.1" + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) + new_prefix = "mid_block.resnets.1" + old_prefix = "middle_block.2" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + + current_layer = 0 + up_block_types = unet_config["up_block_types"] + + for i, layer_type in enumerate(up_block_types): + if layer_type == "ResnetUpsampleBlock2D": + for j in range(layers_per_block + 1): + new_prefix = f"up_blocks.{i}.resnets.{j}" + old_prefix = f"output_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) + current_layer += 1 + + if i != len(up_block_types) - 1: + new_prefix = f"up_blocks.{i}.upsamplers.0" + old_prefix = f"output_blocks.{current_layer-1}.1" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + elif layer_type == "AttnUpBlock2D": + for j in range(layers_per_block + 1): + new_prefix = f"up_blocks.{i}.resnets.{j}" + old_prefix = f"output_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) + new_prefix = f"up_blocks.{i}.attentions.{j}" + old_prefix = f"output_blocks.{current_layer}.1" + new_checkpoint = convert_attention( + checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim + ) + current_layer += 1 + + if i != len(up_block_types) - 1: + new_prefix = f"up_blocks.{i}.upsamplers.0" + old_prefix = f"output_blocks.{current_layer-1}.2" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + + new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"] + new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"] + new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"] + + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.") + parser.add_argument( + "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model." + ) + parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.") + + args = parser.parse_args() + args.class_cond = str2bool(args.class_cond) + + ckpt_name = os.path.basename(args.unet_path) + print(f"Checkpoint: {ckpt_name}") + + # Get U-Net config + if "imagenet64" in ckpt_name: + unet_config = IMAGENET_64_UNET_CONFIG + elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)): + unet_config = LSUN_256_UNET_CONFIG + elif "test" in ckpt_name: + unet_config = TEST_UNET_CONFIG + else: + raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.") + + if not args.class_cond: + unet_config["num_class_embeds"] = None + + converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config) + + image_unet = UNet2DModel(**unet_config) + image_unet.load_state_dict(converted_unet_ckpt) + + # Get scheduler config + if "cd" in ckpt_name or "test" in ckpt_name: + scheduler_config = CD_SCHEDULER_CONFIG + elif "ct" in ckpt_name and "imagenet64" in ckpt_name: + scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG + elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)): + scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG + else: + raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.") + + cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config) + + consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler) + consistency_model.save_pretrained(args.dump_path) diff --git a/diffusers/scripts/convert_dance_diffusion_to_diffusers.py b/diffusers/scripts/convert_dance_diffusion_to_diffusers.py new file mode 100755 index 0000000..ce69bfe --- /dev/null +++ b/diffusers/scripts/convert_dance_diffusion_to_diffusers.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +import argparse +import math +import os +from copy import deepcopy + +import requests +import torch +from audio_diffusion.models import DiffusionAttnUnet1D +from diffusion import sampling +from torch import nn + +from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel + + +MODELS_MAP = { + "gwf-440k": { + "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt", + "sample_rate": 48000, + "sample_size": 65536, + }, + "jmann-small-190k": { + "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt", + "sample_rate": 48000, + "sample_size": 65536, + }, + "jmann-large-580k": { + "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt", + "sample_rate": 48000, + "sample_size": 131072, + }, + "maestro-uncond-150k": { + "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt", + "sample_rate": 16000, + "sample_size": 65536, + }, + "unlocked-uncond-250k": { + "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt", + "sample_rate": 16000, + "sample_size": 65536, + }, + "honk-140k": { + "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt", + "sample_rate": 16000, + "sample_size": 65536, + }, +} + + +def alpha_sigma_to_t(alpha, sigma): + """Returns a timestep, given the scaling factors for the clean image and for + the noise.""" + return torch.atan2(sigma, alpha) / math.pi * 2 + + +def get_crash_schedule(t): + sigma = torch.sin(t * math.pi / 2) ** 2 + alpha = (1 - sigma**2) ** 0.5 + return alpha_sigma_to_t(alpha, sigma) + + +class Object(object): + pass + + +class DiffusionUncond(nn.Module): + def __init__(self, global_args): + super().__init__() + + self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4) + self.diffusion_ema = deepcopy(self.diffusion) + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + +def download(model_name): + url = MODELS_MAP[model_name]["url"] + r = requests.get(url, stream=True) + + local_filename = f"./{model_name}.ckpt" + with open(local_filename, "wb") as fp: + for chunk in r.iter_content(chunk_size=8192): + fp.write(chunk) + + return local_filename + + +DOWN_NUM_TO_LAYER = { + "1": "resnets.0", + "2": "attentions.0", + "3": "resnets.1", + "4": "attentions.1", + "5": "resnets.2", + "6": "attentions.2", +} +UP_NUM_TO_LAYER = { + "8": "resnets.0", + "9": "attentions.0", + "10": "resnets.1", + "11": "attentions.1", + "12": "resnets.2", + "13": "attentions.2", +} +MID_NUM_TO_LAYER = { + "1": "resnets.0", + "2": "attentions.0", + "3": "resnets.1", + "4": "attentions.1", + "5": "resnets.2", + "6": "attentions.2", + "8": "resnets.3", + "9": "attentions.3", + "10": "resnets.4", + "11": "attentions.4", + "12": "resnets.5", + "13": "attentions.5", +} +DEPTH_0_TO_LAYER = { + "0": "resnets.0", + "1": "resnets.1", + "2": "resnets.2", + "4": "resnets.0", + "5": "resnets.1", + "6": "resnets.2", +} + +RES_CONV_MAP = { + "skip": "conv_skip", + "main.0": "conv_1", + "main.1": "group_norm_1", + "main.3": "conv_2", + "main.4": "group_norm_2", +} + +ATTN_MAP = { + "norm": "group_norm", + "qkv_proj": ["query", "key", "value"], + "out_proj": ["proj_attn"], +} + + +def convert_resconv_naming(name): + if name.startswith("skip"): + return name.replace("skip", RES_CONV_MAP["skip"]) + + # name has to be of format main.{digit} + if not name.startswith("main."): + raise ValueError(f"ResConvBlock error with {name}") + + return name.replace(name[:6], RES_CONV_MAP[name[:6]]) + + +def convert_attn_naming(name): + for key, value in ATTN_MAP.items(): + if name.startswith(key) and not isinstance(value, list): + return name.replace(key, value) + elif name.startswith(key): + return [name.replace(key, v) for v in value] + raise ValueError(f"Attn error with {name}") + + +def rename(input_string, max_depth=13): + string = input_string + + if string.split(".")[0] == "timestep_embed": + return string.replace("timestep_embed", "time_proj") + + depth = 0 + if string.startswith("net.3."): + depth += 1 + string = string[6:] + elif string.startswith("net."): + string = string[4:] + + while string.startswith("main.7."): + depth += 1 + string = string[7:] + + if string.startswith("main."): + string = string[5:] + + # mid block + if string[:2].isdigit(): + layer_num = string[:2] + string_left = string[2:] + else: + layer_num = string[0] + string_left = string[1:] + + if depth == max_depth: + new_layer = MID_NUM_TO_LAYER[layer_num] + prefix = "mid_block" + elif depth > 0 and int(layer_num) < 7: + new_layer = DOWN_NUM_TO_LAYER[layer_num] + prefix = f"down_blocks.{depth}" + elif depth > 0 and int(layer_num) > 7: + new_layer = UP_NUM_TO_LAYER[layer_num] + prefix = f"up_blocks.{max_depth - depth - 1}" + elif depth == 0: + new_layer = DEPTH_0_TO_LAYER[layer_num] + prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0" + + if not string_left.startswith("."): + raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.") + + string_left = string_left[1:] + + if "resnets" in new_layer: + string_left = convert_resconv_naming(string_left) + elif "attentions" in new_layer: + new_string_left = convert_attn_naming(string_left) + string_left = new_string_left + + if not isinstance(string_left, list): + new_string = prefix + "." + new_layer + "." + string_left + else: + new_string = [prefix + "." + new_layer + "." + s for s in string_left] + return new_string + + +def rename_orig_weights(state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + if k.endswith("kernel"): + # up- and downsample layers, don't have trainable weights + continue + + new_k = rename(k) + + # check if we need to transform from Conv => Linear for attention + if isinstance(new_k, list): + new_state_dict = transform_conv_attns(new_state_dict, new_k, v) + else: + new_state_dict[new_k] = v + + return new_state_dict + + +def transform_conv_attns(new_state_dict, new_k, v): + if len(new_k) == 1: + if len(v.shape) == 3: + # weight + new_state_dict[new_k[0]] = v[:, :, 0] + else: + # bias + new_state_dict[new_k[0]] = v + else: + # qkv matrices + trippled_shape = v.shape[0] + single_shape = trippled_shape // 3 + for i in range(3): + if len(v.shape) == 3: + new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0] + else: + new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape] + return new_state_dict + + +def main(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model_name = args.model_path.split("/")[-1].split(".")[0] + if not os.path.isfile(args.model_path): + assert ( + model_name == args.model_path + ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}" + args.model_path = download(model_name) + + sample_rate = MODELS_MAP[model_name]["sample_rate"] + sample_size = MODELS_MAP[model_name]["sample_size"] + + config = Object() + config.sample_size = sample_size + config.sample_rate = sample_rate + config.latent_dim = 0 + + diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate) + diffusers_state_dict = diffusers_model.state_dict() + + orig_model = DiffusionUncond(config) + orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"]) + orig_model = orig_model.diffusion_ema.eval() + orig_model_state_dict = orig_model.state_dict() + renamed_state_dict = rename_orig_weights(orig_model_state_dict) + + renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys()) + diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys()) + + assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}" + assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}" + + for key, value in renamed_state_dict.items(): + assert ( + diffusers_state_dict[key].squeeze().shape == value.squeeze().shape + ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}" + if key == "time_proj.weight": + value = value.squeeze() + + diffusers_state_dict[key] = value + + diffusers_model.load_state_dict(diffusers_state_dict) + + steps = 100 + seed = 33 + + diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps) + + generator = torch.manual_seed(seed) + noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device) + + t = torch.linspace(1, 0, steps + 1, device=device)[:-1] + step_list = get_crash_schedule(t) + + pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler) + + generator = torch.manual_seed(33) + audio = pipe(num_inference_steps=steps, generator=generator).audios + + generated = sampling.iplms_sample(orig_model, noise, step_list, {}) + generated = generated.clamp(-1, 1) + + diff_sum = (generated - audio).abs().sum() + diff_max = (generated - audio).abs().max() + + if args.save: + pipe.save_pretrained(args.checkpoint_path) + + print("Diff sum", diff_sum) + print("Diff max", diff_max) + + assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/" + + print(f"Conversion for {model_name} successful!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") + parser.add_argument( + "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not." + ) + parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + args = parser.parse_args() + + main(args) diff --git a/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py b/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py new file mode 100755 index 0000000..4659578 --- /dev/null +++ b/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py @@ -0,0 +1,431 @@ +import argparse +import json + +import torch + +from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + mapping = [] + for old_item in old_list: + new_item = old_item + new_item = new_item.replace("block.", "resnets.") + new_item = new_item.replace("conv_shorcut", "conv1") + new_item = new_item.replace("in_shortcut", "conv_shortcut") + new_item = new_item.replace("temb_proj", "time_emb_proj") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False): + mapping = [] + for old_item in old_list: + new_item = old_item + + # In `model.mid`, the layer is called `attn`. + if not in_mid: + new_item = new_item.replace("attn", "attentions") + new_item = new_item.replace(".k.", ".key.") + new_item = new_item.replace(".v.", ".value.") + new_item = new_item.replace(".q.", ".query.") + + new_item = new_item.replace("proj_out", "proj_attn") + new_item = new_item.replace("norm", "group_norm") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + if attention_paths_to_split is not None: + if config is None: + raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.") + + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze() + checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze() + checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze() + + for path in paths: + new_path = path["new"] + + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + new_path = new_path.replace("down.", "down_blocks.") + new_path = new_path.replace("up.", "up_blocks.") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + if "attentions" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]].squeeze() + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def convert_ddpm_checkpoint(checkpoint, config): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"] + + new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"] + new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"] + + new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"] + new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"] + new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"] + new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"] + + num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer}) + down_blocks = { + layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer}) + up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} + + for i in range(num_down_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + + if any("downsample" in layer for layer in down_blocks[i]): + new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[ + f"down.{i}.downsample.op.weight" + ] + new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"] + # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight'] + # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias'] + + if any("block" in layer for layer in down_blocks[i]): + num_blocks = len( + {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer} + ) + blocks = { + layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key] + for layer_id in range(num_blocks) + } + + if num_blocks > 0: + for j in range(config["layers_per_block"]): + paths = renew_resnet_paths(blocks[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint) + + if any("attn" in layer for layer in down_blocks[i]): + num_attn = len( + {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer} + ) + attns = { + layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key] + for layer_id in range(num_blocks) + } + + if num_attn > 0: + for j in range(config["layers_per_block"]): + paths = renew_attention_paths(attns[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config) + + mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key] + mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key] + mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key] + + # Mid new 2 + paths = renew_resnet_paths(mid_block_1_layers) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}], + ) + + paths = renew_resnet_paths(mid_block_2_layers) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}], + ) + + paths = renew_attention_paths(mid_attn_1_layers, in_mid=True) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}], + ) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + + if any("upsample" in layer for layer in up_blocks[i]): + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[ + f"up.{i}.upsample.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"] + + if any("block" in layer for layer in up_blocks[i]): + num_blocks = len( + {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer} + ) + blocks = { + layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks) + } + + if num_blocks > 0: + for j in range(config["layers_per_block"] + 1): + replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"} + paths = renew_resnet_paths(blocks[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) + + if any("attn" in layer for layer in up_blocks[i]): + num_attn = len( + {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer} + ) + attns = { + layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks) + } + + if num_attn > 0: + for j in range(config["layers_per_block"] + 1): + replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"} + paths = renew_attention_paths(attns[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) + + new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()} + return new_checkpoint + + +def convert_vq_autoenc_checkpoint(checkpoint, config): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + new_checkpoint = {} + + new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"] + + new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"] + + new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"] + + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer}) + down_blocks = { + layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer}) + up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} + + for i in range(num_down_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + + if any("downsample" in layer for layer in down_blocks[i]): + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[ + f"encoder.down.{i}.downsample.conv.weight" + ] + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[ + f"encoder.down.{i}.downsample.conv.bias" + ] + + if any("block" in layer for layer in down_blocks[i]): + num_blocks = len( + {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer} + ) + blocks = { + layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key] + for layer_id in range(num_blocks) + } + + if num_blocks > 0: + for j in range(config["layers_per_block"]): + paths = renew_resnet_paths(blocks[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint) + + if any("attn" in layer for layer in down_blocks[i]): + num_attn = len( + {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer} + ) + attns = { + layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key] + for layer_id in range(num_blocks) + } + + if num_attn > 0: + for j in range(config["layers_per_block"]): + paths = renew_attention_paths(attns[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config) + + mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key] + mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key] + mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key] + + # Mid new 2 + paths = renew_resnet_paths(mid_block_1_layers) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}], + ) + + paths = renew_resnet_paths(mid_block_2_layers) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}], + ) + + paths = renew_attention_paths(mid_attn_1_layers, in_mid=True) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}], + ) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + + if any("upsample" in layer for layer in up_blocks[i]): + new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[ + f"decoder.up.{i}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[ + f"decoder.up.{i}.upsample.conv.bias" + ] + + if any("block" in layer for layer in up_blocks[i]): + num_blocks = len( + {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer} + ) + blocks = { + layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks) + } + + if num_blocks > 0: + for j in range(config["layers_per_block"] + 1): + replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"} + paths = renew_resnet_paths(blocks[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) + + if any("attn" in layer for layer in up_blocks[i]): + num_attn = len( + {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer} + ) + attns = { + layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks) + } + + if num_attn > 0: + for j in range(config["layers_per_block"] + 1): + replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"} + paths = renew_attention_paths(attns[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) + + new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()} + new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"] + if "quantize.embedding.weight" in checkpoint: + new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"] + new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"] + + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the architecture.", + ) + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + args = parser.parse_args() + checkpoint = torch.load(args.checkpoint_path) + + with open(args.config_file) as f: + config = json.loads(f.read()) + + # unet case + key_prefix_set = {key.split(".")[0] for key in checkpoint.keys()} + if "encoder" in key_prefix_set and "decoder" in key_prefix_set: + converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config) + else: + converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config) + + if "ddpm" in config: + del config["ddpm"] + + if config["_class_name"] == "VQModel": + model = VQModel(**config) + model.load_state_dict(converted_checkpoint) + model.save_pretrained(args.dump_path) + elif config["_class_name"] == "AutoencoderKL": + model = AutoencoderKL(**config) + model.load_state_dict(converted_checkpoint) + model.save_pretrained(args.dump_path) + else: + model = UNet2DModel(**config) + model.load_state_dict(converted_checkpoint) + + scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1])) + + pipe = DDPMPipeline(unet=model, scheduler=scheduler) + pipe.save_pretrained(args.dump_path) diff --git a/diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py b/diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py new file mode 100755 index 0000000..dfb3871 --- /dev/null +++ b/diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py @@ -0,0 +1,56 @@ +# Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format +# This means that you can input your diffusers-trained LoRAs and +# Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others. + +# To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy +# https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file +# and run the script: +# python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors +# now you can use corgy.safetensors in your WebUI of choice! + +# To train your own, here are some diffusers training scripts and utils that you can use and then convert: +# LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease +# Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer: +# - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +# - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb +# Canonical diffusers training scripts: +# - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py +# - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb + +import argparse +import os + +from safetensors.torch import load_file, save_file + +from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya + + +def convert_and_save(input_lora, output_lora=None): + if output_lora is None: + base_name = os.path.splitext(input_lora)[0] + output_lora = f"{base_name}_webui.safetensors" + + diffusers_state_dict = load_file(input_lora) + peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict) + kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) + save_file(kohya_state_dict, output_lora) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.") + parser.add_argument( + "--input_lora", + type=str, + required=True, + help="Path to the input LoRA model file in the diffusers format.", + ) + parser.add_argument( + "--output_lora", + type=str, + required=False, + help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.", + ) + + args = parser.parse_args() + + convert_and_save(args.input_lora, args.output_lora) diff --git a/diffusers/scripts/convert_diffusers_to_original_sdxl.py b/diffusers/scripts/convert_diffusers_to_original_sdxl.py new file mode 100755 index 0000000..648d037 --- /dev/null +++ b/diffusers/scripts/convert_diffusers_to_original_sdxl.py @@ -0,0 +1,350 @@ +# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. +# *Only* converts the UNet, VAE, and Text Encoder. +# Does not convert optimizer state or any other thing. + +import argparse +import os.path as osp +import re + +import torch +from safetensors.torch import load_file, save_file + + +# =================# +# UNet Conversion # +# =================# + +unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + # the following are for sdxl + ("label_emb.0.0.weight", "add_embedding.linear_1.weight"), + ("label_emb.0.0.bias", "add_embedding.linear_1.bias"), + ("label_emb.0.2.weight", "add_embedding.linear_2.weight"), + ("label_emb.0.2.bias", "add_embedding.linear_2.bias"), +] + +unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), +] + +unet_conversion_map_layer = [] +# hardcoded number of downblocks and resnets/attentions... +# would need smarter logic for other networks. +for i in range(3): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i > 0: + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(4): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i < 2: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) +unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv.")) + +hf_mid_atn_prefix = "mid_block.attentions.0." +sd_mid_atn_prefix = "middle_block.1." +unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) +for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + +def convert_unet_state_dict(unet_state_dict): + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()} + return new_state_dict + + +# ================# +# VAE Conversion # +# ================# + +vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), +] + +for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + +# this part accounts for mid blocks in both the encoder and the decoder +for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + +vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + # the following are for SDXL + ("q.", "to_q."), + ("k.", "to_k."), + ("v.", "to_v."), + ("proj_out.", "to_out.0."), +] + + +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + if not w.ndim == 1: + return w.reshape(*w.shape, 1, 1) + else: + return w + + +def convert_vae_state_dict(vae_state_dict): + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) + return new_state_dict + + +# =========================# +# Text Encoder Conversion # +# =========================# + + +textenc_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("transformer.resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "text_model.final_layer_norm."), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + +# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp +code2idx = {"q": 0, "k": 1, "v": 2} + + +def convert_openclip_text_enc_state_dict(text_enc_dict): + new_state_dict = {} + capture_qkv_weight = {} + capture_qkv_bias = {} + for k, v in text_enc_dict.items(): + if ( + k.endswith(".self_attn.q_proj.weight") + or k.endswith(".self_attn.k_proj.weight") + or k.endswith(".self_attn.v_proj.weight") + ): + k_pre = k[: -len(".q_proj.weight")] + k_code = k[-len("q_proj.weight")] + if k_pre not in capture_qkv_weight: + capture_qkv_weight[k_pre] = [None, None, None] + capture_qkv_weight[k_pre][code2idx[k_code]] = v + continue + + if ( + k.endswith(".self_attn.q_proj.bias") + or k.endswith(".self_attn.k_proj.bias") + or k.endswith(".self_attn.v_proj.bias") + ): + k_pre = k[: -len(".q_proj.bias")] + k_code = k[-len("q_proj.bias")] + if k_pre not in capture_qkv_bias: + capture_qkv_bias[k_pre] = [None, None, None] + capture_qkv_bias[k_pre][code2idx[k_code]] = v + continue + + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v + + for k_pre, tensors in capture_qkv_weight.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + + for k_pre, tensors in capture_qkv_bias.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + + return new_state_dict + + +def convert_openai_text_enc_state_dict(text_enc_dict): + return text_enc_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") + parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--half", action="store_true", help="Save weights in half precision.") + parser.add_argument( + "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt." + ) + + args = parser.parse_args() + + assert args.model_path is not None, "Must provide a model path!" + + assert args.checkpoint_path is not None, "Must provide a checkpoint path!" + + # Path for safetensors + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors") + text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + if osp.exists(text_enc_2_path): + text_enc_2_dict = load_file(text_enc_2_path, device="cpu") + else: + text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "pytorch_model.bin") + text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict(unet_state_dict) + unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + vae_state_dict = convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Convert text encoder 1 + text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict) + text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()} + + # Convert text encoder 2 + text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict) + text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()} + # We call the `.T.contiguous()` to match what's done in + # https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085 + text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop( + "conditioner.embedders.1.model.text_projection.weight" + ).T.contiguous() + + # Put together new checkpoint + state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict} + + if args.half: + state_dict = {k: v.half() for k, v in state_dict.items()} + + if args.use_safetensors: + save_file(state_dict, args.checkpoint_path) + else: + state_dict = {"state_dict": state_dict} + torch.save(state_dict, args.checkpoint_path) diff --git a/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py b/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py new file mode 100755 index 0000000..d1b7df0 --- /dev/null +++ b/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -0,0 +1,353 @@ +# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. +# *Only* converts the UNet, VAE, and Text Encoder. +# Does not convert optimizer state or any other thing. + +import argparse +import os.path as osp +import re + +import torch +from safetensors.torch import load_file, save_file + + +# =================# +# UNet Conversion # +# =================# + +unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), +] + +unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), +] + +unet_conversion_map_layer = [] +# hardcoded number of downblocks and resnets/attentions... +# would need smarter logic for other networks. +for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + +hf_mid_atn_prefix = "mid_block.attentions.0." +sd_mid_atn_prefix = "middle_block.1." +unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + +for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + +def convert_unet_state_dict(unet_state_dict): + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + +# ================# +# VAE Conversion # +# ================# + +vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), +] + +for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + +# this part accounts for mid blocks in both the encoder and the decoder +for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + +vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), +] + +# This is probably not the most ideal solution, but it does work. +vae_extra_conversion_map = [ + ("to_q", "q"), + ("to_k", "k"), + ("to_v", "v"), + ("to_out.0", "proj_out"), +] + + +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + if not w.ndim == 1: + return w.reshape(*w.shape, 1, 1) + else: + return w + + +def convert_vae_state_dict(vae_state_dict): + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + keys_to_rename = {} + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) + for weight_name, real_weight_name in vae_extra_conversion_map: + if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k: + keys_to_rename[k] = k.replace(weight_name, real_weight_name) + for k, v in keys_to_rename.items(): + if k in new_state_dict: + print(f"Renaming {k} to {v}") + new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k]) + del new_state_dict[k] + return new_state_dict + + +# =========================# +# Text Encoder Conversion # +# =========================# + + +textenc_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + +# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp +code2idx = {"q": 0, "k": 1, "v": 2} + + +def convert_text_enc_state_dict_v20(text_enc_dict): + new_state_dict = {} + capture_qkv_weight = {} + capture_qkv_bias = {} + for k, v in text_enc_dict.items(): + if ( + k.endswith(".self_attn.q_proj.weight") + or k.endswith(".self_attn.k_proj.weight") + or k.endswith(".self_attn.v_proj.weight") + ): + k_pre = k[: -len(".q_proj.weight")] + k_code = k[-len("q_proj.weight")] + if k_pre not in capture_qkv_weight: + capture_qkv_weight[k_pre] = [None, None, None] + capture_qkv_weight[k_pre][code2idx[k_code]] = v + continue + + if ( + k.endswith(".self_attn.q_proj.bias") + or k.endswith(".self_attn.k_proj.bias") + or k.endswith(".self_attn.v_proj.bias") + ): + k_pre = k[: -len(".q_proj.bias")] + k_code = k[-len("q_proj.bias")] + if k_pre not in capture_qkv_bias: + capture_qkv_bias[k_pre] = [None, None, None] + capture_qkv_bias[k_pre][code2idx[k_code]] = v + continue + + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v + + for k_pre, tensors in capture_qkv_weight.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + + for k_pre, tensors in capture_qkv_bias.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + + return new_state_dict + + +def convert_text_enc_state_dict(text_enc_dict): + return text_enc_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") + parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--half", action="store_true", help="Save weights in half precision.") + parser.add_argument( + "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt." + ) + + args = parser.parse_args() + + assert args.model_path is not None, "Must provide a model path!" + + assert args.checkpoint_path is not None, "Must provide a checkpoint path!" + + # Path for safetensors + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict(unet_state_dict) + unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + vae_state_dict = convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper + is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict + + if is_v20_model: + # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm + text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} + text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) + text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} + else: + text_enc_dict = convert_text_enc_state_dict(text_enc_dict) + text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} + + # Put together new checkpoint + state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + if args.half: + state_dict = {k: v.half() for k, v in state_dict.items()} + + if args.use_safetensors: + save_file(state_dict, args.checkpoint_path) + else: + state_dict = {"state_dict": state_dict} + torch.save(state_dict, args.checkpoint_path) diff --git a/diffusers/scripts/convert_dit_to_diffusers.py b/diffusers/scripts/convert_dit_to_diffusers.py new file mode 100755 index 0000000..dc127f6 --- /dev/null +++ b/diffusers/scripts/convert_dit_to_diffusers.py @@ -0,0 +1,162 @@ +import argparse +import os + +import torch +from torchvision.datasets.utils import download_url + +from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel + + +pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"} + + +def download_model(model_name): + """ + Downloads a pre-trained DiT model from the web. + """ + local_path = f"pretrained_models/{model_name}" + if not os.path.isfile(local_path): + os.makedirs("pretrained_models", exist_ok=True) + web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}" + download_url(web_path, "pretrained_models") + model = torch.load(local_path, map_location=lambda storage, loc: storage) + return model + + +def main(args): + state_dict = download_model(pretrained_models[args.image_size]) + + state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"] + state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"] + state_dict.pop("x_embedder.proj.weight") + state_dict.pop("x_embedder.proj.bias") + + for depth in range(28): + state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[ + "t_embedder.mlp.0.weight" + ] + state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[ + "t_embedder.mlp.0.bias" + ] + state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[ + "t_embedder.mlp.2.weight" + ] + state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[ + "t_embedder.mlp.2.bias" + ] + state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[ + "y_embedder.embedding_table.weight" + ] + + state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[ + f"blocks.{depth}.adaLN_modulation.1.weight" + ] + state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[ + f"blocks.{depth}.adaLN_modulation.1.bias" + ] + + q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0) + + state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q + state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias + state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k + state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias + state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v + state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias + + state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[ + f"blocks.{depth}.attn.proj.weight" + ] + state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"] + + state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"] + state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"] + state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"] + state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"] + + state_dict.pop(f"blocks.{depth}.attn.qkv.weight") + state_dict.pop(f"blocks.{depth}.attn.qkv.bias") + state_dict.pop(f"blocks.{depth}.attn.proj.weight") + state_dict.pop(f"blocks.{depth}.attn.proj.bias") + state_dict.pop(f"blocks.{depth}.mlp.fc1.weight") + state_dict.pop(f"blocks.{depth}.mlp.fc1.bias") + state_dict.pop(f"blocks.{depth}.mlp.fc2.weight") + state_dict.pop(f"blocks.{depth}.mlp.fc2.bias") + state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight") + state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias") + + state_dict.pop("t_embedder.mlp.0.weight") + state_dict.pop("t_embedder.mlp.0.bias") + state_dict.pop("t_embedder.mlp.2.weight") + state_dict.pop("t_embedder.mlp.2.bias") + state_dict.pop("y_embedder.embedding_table.weight") + + state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"] + state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"] + state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"] + state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"] + + state_dict.pop("final_layer.linear.weight") + state_dict.pop("final_layer.linear.bias") + state_dict.pop("final_layer.adaLN_modulation.1.weight") + state_dict.pop("final_layer.adaLN_modulation.1.bias") + + # DiT XL/2 + transformer = Transformer2DModel( + sample_size=args.image_size // 8, + num_layers=28, + attention_head_dim=72, + in_channels=4, + out_channels=8, + patch_size=2, + attention_bias=True, + num_attention_heads=16, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + norm_type="ada_norm_zero", + norm_elementwise_affine=False, + ) + transformer.load_state_dict(state_dict, strict=True) + + scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_schedule="linear", + prediction_type="epsilon", + clip_sample=False, + ) + + vae = AutoencoderKL.from_pretrained(args.vae_model) + + pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler) + + if args.save: + pipeline.save_pretrained(args.checkpoint_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--image_size", + default=256, + type=int, + required=False, + help="Image size of pretrained model, either 256 or 512.", + ) + parser.add_argument( + "--vae_model", + default="stabilityai/sd-vae-ft-ema", + type=str, + required=False, + help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.", + ) + parser.add_argument( + "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not." + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline." + ) + + args = parser.parse_args() + main(args) diff --git a/diffusers/scripts/convert_flux_to_diffusers.py b/diffusers/scripts/convert_flux_to_diffusers.py new file mode 100755 index 0000000..05a1da2 --- /dev/null +++ b/diffusers/scripts/convert_flux_to_diffusers.py @@ -0,0 +1,303 @@ +import argparse +from contextlib import nullcontext + +import safetensors.torch +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers import AutoencoderKL, FluxTransformer2DModel +from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint +from diffusers.utils.import_utils import is_accelerate_available + + +""" +# Transformer + +python scripts/convert_flux_to_diffusers.py \ +--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \ +--filename "flux1-schnell.sft" +--output_path "flux-schnell" \ +--transformer +""" + +""" +# VAE + +python scripts/convert_flux_to_diffusers.py \ +--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \ +--filename "ae.sft" +--output_path "flux-schnell" \ +--vae +""" + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--original_state_dict_repo_id", default=None, type=str) +parser.add_argument("--filename", default="flux.safetensors", type=str) +parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--vae", action="store_true") +parser.add_argument("--transformer", action="store_true") +parser.add_argument("--output_path", type=str) +parser.add_argument("--dtype", type=str, default="bf16") + +args = parser.parse_args() +dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 + + +def load_original_checkpoint(args): + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_flux_transformer_checkpoint_to_diffusers( + original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0 +): + converted_state_dict = {} + + ## time_text_embed.timestep_embedder <- time_in + converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop( + "time_in.in_layer.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop( + "time_in.in_layer.bias" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop( + "time_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop( + "time_in.out_layer.bias" + ) + + ## time_text_embed.text_embedder <- vector_in + converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop( + "vector_in.in_layer.weight" + ) + converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop( + "vector_in.in_layer.bias" + ) + converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop( + "vector_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop( + "vector_in.out_layer.bias" + ) + + # guidance + has_guidance = any("guidance" in k for k in original_state_dict) + if has_guidance: + converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop( + "guidance_in.in_layer.weight" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop( + "guidance_in.in_layer.bias" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop( + "guidance_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop( + "guidance_in.out_layer.bias" + ) + + # context_embedder + converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight") + converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias") + + # x_embedder + converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight") + converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias") + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + # norms. + ## norm1 + converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.bias" + ) + ## norm1_context + converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.bias" + ) + # Q, K, V + sample_q, sample_k, sample_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0 + ) + context_q, context_k, context_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk_norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.key_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.key_norm.scale" + ) + # ff img_mlp + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.0.bias" + ) + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.2.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.2.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.0.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.2.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.2.bias" + ) + # output projections. + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.bias" + ) + + # single transfomer blocks + for i in range(num_single_layers): + block_prefix = f"single_transformer_blocks.{i}." + # norm.linear <- single_blocks.0.modulation.lin + converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.bias" + ) + # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) + q_bias, k_bias, v_bias, mlp_bias = torch.split( + original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) + converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.key_norm.scale" + ) + # output projections. + converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.weight" + ) + converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.bias" + ) + + converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.weight") + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.bias") + ) + + return converted_state_dict + + +def main(args): + original_ckpt = load_original_checkpoint(args) + has_guidance = any("guidance" in k for k in original_ckpt) + + if args.transformer: + num_layers = 19 + num_single_layers = 38 + inner_dim = 3072 + mlp_ratio = 4.0 + converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers( + original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio + ) + transformer = FluxTransformer2DModel(guidance_embeds=has_guidance) + transformer.load_state_dict(converted_transformer_state_dict, strict=True) + + print( + f"Saving Flux Transformer in Diffusers format. Variant: {'guidance-distilled' if has_guidance else 'timestep-distilled'}" + ) + transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer") + + if args.vae: + config = AutoencoderKL.load_config("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae") + vae = AutoencoderKL.from_config(config, scaling_factor=0.3611, shift_factor=0.1159).to(torch.bfloat16) + + converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config) + vae.load_state_dict(converted_vae_state_dict, strict=True) + vae.to(dtype).save_pretrained(f"{args.output_path}/vae") + + +if __name__ == "__main__": + main(args) diff --git a/diffusers/scripts/convert_gligen_to_diffusers.py b/diffusers/scripts/convert_gligen_to_diffusers.py new file mode 100755 index 0000000..83c1f92 --- /dev/null +++ b/diffusers/scripts/convert_gligen_to_diffusers.py @@ -0,0 +1,581 @@ +import argparse +import re + +import torch +import yaml +from transformers import ( + CLIPProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + StableDiffusionGLIGENPipeline, + StableDiffusionGLIGENTextImagePipeline, + UNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + assign_to_checkpoint, + conv_attn_to_linear, + protected, + renew_attention_paths, + renew_resnet_paths, + renew_vae_attention_paths, + renew_vae_resnet_paths, + shave_segments, + textenc_conversion_map, + textenc_pattern, +) + + +def convert_open_clip_checkpoint(checkpoint): + checkpoint = checkpoint["text_encoder"] + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + if "cond_stage_model.model.text_projection" in checkpoint: + d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) + else: + d_model = 1024 + + for key in keys: + if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer + continue + if key in textenc_conversion_map: + text_model_dict[textenc_conversion_map[key]] = checkpoint[key] + # if key.startswith("cond_stage_model.model.transformer."): + new_key = key[len("transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + if key != "transformer.text_model.embeddings.position_ids": + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + if key == "transformer.text_model.embeddings.token_embedding.weight": + text_model_dict["text_model.embeddings.token_embedding.weight"] = checkpoint[key] + + text_model_dict.pop("text_model.embeddings.transformer.text_model.embeddings.token_embedding.weight") + + text_model.load_state_dict(text_model_dict) + + return text_model + + +def convert_gligen_vae_checkpoint(checkpoint, config): + checkpoint = checkpoint["autoencoder"] + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for key in new_checkpoint.keys(): + if "encoder.mid_block.attentions.0" in key or "decoder.mid_block.attentions.0" in key: + if "query" in key: + new_checkpoint[key.replace("query", "to_q")] = new_checkpoint.pop(key) + if "value" in key: + new_checkpoint[key.replace("value", "to_v")] = new_checkpoint.pop(key) + if "key" in key: + new_checkpoint[key.replace("key", "to_k")] = new_checkpoint.pop(key) + if "proj_attn" in key: + new_checkpoint[key.replace("proj_attn", "to_out.0")] = new_checkpoint.pop(key) + + return new_checkpoint + + +def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + unet_state_dict = {} + checkpoint = checkpoint["model"] + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has bot EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + for key in keys: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + for key in keys: + if "position_net" in key: + new_checkpoint[key] = unet_state_dict[key] + + return new_checkpoint + + +def create_vae_config(original_config, image_size: int): + vae_params = original_config["autoencoder"]["params"]["ddconfig"] + _ = original_config["autoencoder"]["params"]["embed_dim"] + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + } + + return config + + +def create_unet_config(original_config, image_size: int, attention_type): + unet_params = original_config["model"]["params"] + vae_params = original_config["autoencoder"]["params"]["ddconfig"] + + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + if head_dim is None: + head_dim = [5, 10, 20, 20] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": unet_params["context_dim"], + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "attention_type": attention_type, + } + + return config + + +def convert_gligen_to_diffusers( + checkpoint_path: str, + original_config_file: str, + attention_type: str, + image_size: int = 512, + extract_ema: bool = False, + num_in_channels: int = None, + device: str = None, +): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + if "global_step" in checkpoint: + checkpoint["global_step"] + else: + print("global_step key not found in model") + + original_config = yaml.safe_load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["in_channels"] = num_in_channels + + num_train_timesteps = original_config["diffusion"]["params"]["timesteps"] + beta_start = original_config["diffusion"]["params"]["linear_start"] + beta_end = original_config["diffusion"]["params"]["linear_end"] + + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type="epsilon", + ) + + # Convert the UNet2DConditionalModel model + unet_config = create_unet_config(original_config, image_size, attention_type) + unet = UNet2DConditionModel(**unet_config) + + converted_unet_checkpoint = convert_gligen_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model + vae_config = create_vae_config(original_config, image_size) + converted_vae_checkpoint = convert_gligen_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + + # Convert the text model + text_encoder = convert_open_clip_checkpoint(checkpoint) + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + + if attention_type == "gated-text-image": + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") + + pipe = StableDiffusionGLIGENTextImagePipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + processor=processor, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + ) + elif attention_type == "gated": + pipe = StableDiffusionGLIGENPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + ) + + return pipe + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--original_config_file", + default=None, + type=str, + required=True, + help="The YAML config file corresponding to the gligen architecture.", + ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--attention_type", + default=None, + type=str, + required=True, + help="Type of attention ex: gated or gated-text-image", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use.") + parser.add_argument("--half", action="store_true", help="Save weights in half precision.") + + args = parser.parse_args() + + pipe = convert_gligen_to_diffusers( + checkpoint_path=args.checkpoint_path, + original_config_file=args.original_config_file, + attention_type=args.attention_type, + extract_ema=args.extract_ema, + num_in_channels=args.num_in_channels, + device=args.device, + ) + + if args.half: + pipe.to(dtype=torch.float16) + + pipe.save_pretrained(args.dump_path) diff --git a/diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py b/diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py new file mode 100755 index 0000000..1c83836 --- /dev/null +++ b/diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py @@ -0,0 +1,241 @@ +import argparse + +import torch + +from diffusers import HunyuanDiT2DControlNetModel + + +def main(args): + state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu") + + if args.load_key != "none": + try: + state_dict = state_dict[args.load_key] + except KeyError: + raise KeyError( + f"{args.load_key} not found in the checkpoint." + "Please load from the following keys:{state_dict.keys()}" + ) + device = "cuda" + + model_config = HunyuanDiT2DControlNetModel.load_config( + "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer" + ) + model_config[ + "use_style_cond_and_image_meta_size" + ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False + print(model_config) + + for key in state_dict: + print("local:", key) + + model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device) + + for key in model.state_dict(): + print("diffusers:", key) + + num_layers = 19 + for i in range(num_layers): + # attn1 + # Wkqv -> to_q, to_k, to_v + q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0) + state_dict[f"blocks.{i}.attn1.to_q.weight"] = q + state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias + state_dict[f"blocks.{i}.attn1.to_k.weight"] = k + state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias + state_dict[f"blocks.{i}.attn1.to_v.weight"] = v + state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias + state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight") + state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias") + + # q_norm, k_norm -> norm_q, norm_k + state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"] + state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"] + state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"] + state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"] + + state_dict.pop(f"blocks.{i}.attn1.q_norm.weight") + state_dict.pop(f"blocks.{i}.attn1.q_norm.bias") + state_dict.pop(f"blocks.{i}.attn1.k_norm.weight") + state_dict.pop(f"blocks.{i}.attn1.k_norm.bias") + + # out_proj -> to_out + state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"] + state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"] + state_dict.pop(f"blocks.{i}.attn1.out_proj.weight") + state_dict.pop(f"blocks.{i}.attn1.out_proj.bias") + + # attn2 + # kq_proj -> to_k, to_v + k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0) + k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0) + state_dict[f"blocks.{i}.attn2.to_k.weight"] = k + state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias + state_dict[f"blocks.{i}.attn2.to_v.weight"] = v + state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias + state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight") + state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias") + + # q_proj -> to_q + state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"] + state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"] + state_dict.pop(f"blocks.{i}.attn2.q_proj.weight") + state_dict.pop(f"blocks.{i}.attn2.q_proj.bias") + + # q_norm, k_norm -> norm_q, norm_k + state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"] + state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"] + state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"] + state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"] + + state_dict.pop(f"blocks.{i}.attn2.q_norm.weight") + state_dict.pop(f"blocks.{i}.attn2.q_norm.bias") + state_dict.pop(f"blocks.{i}.attn2.k_norm.weight") + state_dict.pop(f"blocks.{i}.attn2.k_norm.bias") + + # out_proj -> to_out + state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"] + state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"] + state_dict.pop(f"blocks.{i}.attn2.out_proj.weight") + state_dict.pop(f"blocks.{i}.attn2.out_proj.bias") + + # switch norm 2 and norm 3 + norm2_weight = state_dict[f"blocks.{i}.norm2.weight"] + norm2_bias = state_dict[f"blocks.{i}.norm2.bias"] + state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"] + state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"] + state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight + state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias + + # norm1 -> norm1.norm + # default_modulation.1 -> norm1.linear + state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"] + state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"] + state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"] + state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"] + state_dict.pop(f"blocks.{i}.norm1.weight") + state_dict.pop(f"blocks.{i}.norm1.bias") + state_dict.pop(f"blocks.{i}.default_modulation.1.weight") + state_dict.pop(f"blocks.{i}.default_modulation.1.bias") + + # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2 + state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"] + state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"] + state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"] + state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"] + state_dict.pop(f"blocks.{i}.mlp.fc1.weight") + state_dict.pop(f"blocks.{i}.mlp.fc1.bias") + state_dict.pop(f"blocks.{i}.mlp.fc2.weight") + state_dict.pop(f"blocks.{i}.mlp.fc2.bias") + + # after_proj_list -> controlnet_blocks + state_dict[f"controlnet_blocks.{i}.weight"] = state_dict[f"after_proj_list.{i}.weight"] + state_dict[f"controlnet_blocks.{i}.bias"] = state_dict[f"after_proj_list.{i}.bias"] + state_dict.pop(f"after_proj_list.{i}.weight") + state_dict.pop(f"after_proj_list.{i}.bias") + + # before_proj -> input_block + state_dict["input_block.weight"] = state_dict["before_proj.weight"] + state_dict["input_block.bias"] = state_dict["before_proj.bias"] + state_dict.pop("before_proj.weight") + state_dict.pop("before_proj.bias") + + # pooler -> time_extra_emb + state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"] + state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"] + state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"] + state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"] + state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"] + state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"] + state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"] + state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"] + state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"] + state_dict.pop("pooler.k_proj.weight") + state_dict.pop("pooler.k_proj.bias") + state_dict.pop("pooler.q_proj.weight") + state_dict.pop("pooler.q_proj.bias") + state_dict.pop("pooler.v_proj.weight") + state_dict.pop("pooler.v_proj.bias") + state_dict.pop("pooler.c_proj.weight") + state_dict.pop("pooler.c_proj.bias") + state_dict.pop("pooler.positional_embedding") + + # t_embedder -> time_embedding (`TimestepEmbedding`) + state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"] + state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"] + state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"] + state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"] + + state_dict.pop("t_embedder.mlp.0.bias") + state_dict.pop("t_embedder.mlp.0.weight") + state_dict.pop("t_embedder.mlp.2.bias") + state_dict.pop("t_embedder.mlp.2.weight") + + # x_embedder -> pos_embd (`PatchEmbed`) + state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"] + state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"] + state_dict.pop("x_embedder.proj.weight") + state_dict.pop("x_embedder.proj.bias") + + # mlp_t5 -> text_embedder + state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"] + state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"] + state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"] + state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"] + state_dict.pop("mlp_t5.0.bias") + state_dict.pop("mlp_t5.0.weight") + state_dict.pop("mlp_t5.2.bias") + state_dict.pop("mlp_t5.2.weight") + + # extra_embedder -> extra_embedder + state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"] + state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"] + state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"] + state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"] + state_dict.pop("extra_embedder.0.bias") + state_dict.pop("extra_embedder.0.weight") + state_dict.pop("extra_embedder.2.bias") + state_dict.pop("extra_embedder.2.weight") + + # style_embedder + if model_config["use_style_cond_and_image_meta_size"]: + print(state_dict["style_embedder.weight"]) + print(state_dict["style_embedder.weight"].shape) + state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1] + state_dict.pop("style_embedder.weight") + + model.load_state_dict(state_dict) + + if args.save: + model.save_pretrained(args.output_checkpoint_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not." + ) + parser.add_argument( + "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model." + ) + parser.add_argument( + "--output_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the output converted diffusers pipeline.", + ) + parser.add_argument( + "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file" + ) + parser.add_argument( + "--use_style_cond_and_image_meta_size", + type=bool, + default=False, + help="version <= v1.1: True; version >= v1.2: False", + ) + + args = parser.parse_args() + main(args) diff --git a/diffusers/scripts/convert_hunyuandit_to_diffusers.py b/diffusers/scripts/convert_hunyuandit_to_diffusers.py new file mode 100755 index 0000000..da3af83 --- /dev/null +++ b/diffusers/scripts/convert_hunyuandit_to_diffusers.py @@ -0,0 +1,267 @@ +import argparse + +import torch + +from diffusers import HunyuanDiT2DModel + + +def main(args): + state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu") + + if args.load_key != "none": + try: + state_dict = state_dict[args.load_key] + except KeyError: + raise KeyError( + f"{args.load_key} not found in the checkpoint." + f"Please load from the following keys:{state_dict.keys()}" + ) + + device = "cuda" + model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer") + model_config[ + "use_style_cond_and_image_meta_size" + ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False + + # input_size -> sample_size, text_dim -> cross_attention_dim + for key in state_dict: + print("local:", key) + + model = HunyuanDiT2DModel.from_config(model_config).to(device) + + for key in model.state_dict(): + print("diffusers:", key) + + num_layers = 40 + for i in range(num_layers): + # attn1 + # Wkqv -> to_q, to_k, to_v + q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0) + state_dict[f"blocks.{i}.attn1.to_q.weight"] = q + state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias + state_dict[f"blocks.{i}.attn1.to_k.weight"] = k + state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias + state_dict[f"blocks.{i}.attn1.to_v.weight"] = v + state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias + state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight") + state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias") + + # q_norm, k_norm -> norm_q, norm_k + state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"] + state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"] + state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"] + state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"] + + state_dict.pop(f"blocks.{i}.attn1.q_norm.weight") + state_dict.pop(f"blocks.{i}.attn1.q_norm.bias") + state_dict.pop(f"blocks.{i}.attn1.k_norm.weight") + state_dict.pop(f"blocks.{i}.attn1.k_norm.bias") + + # out_proj -> to_out + state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"] + state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"] + state_dict.pop(f"blocks.{i}.attn1.out_proj.weight") + state_dict.pop(f"blocks.{i}.attn1.out_proj.bias") + + # attn2 + # kq_proj -> to_k, to_v + k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0) + k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0) + state_dict[f"blocks.{i}.attn2.to_k.weight"] = k + state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias + state_dict[f"blocks.{i}.attn2.to_v.weight"] = v + state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias + state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight") + state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias") + + # q_proj -> to_q + state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"] + state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"] + state_dict.pop(f"blocks.{i}.attn2.q_proj.weight") + state_dict.pop(f"blocks.{i}.attn2.q_proj.bias") + + # q_norm, k_norm -> norm_q, norm_k + state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"] + state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"] + state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"] + state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"] + + state_dict.pop(f"blocks.{i}.attn2.q_norm.weight") + state_dict.pop(f"blocks.{i}.attn2.q_norm.bias") + state_dict.pop(f"blocks.{i}.attn2.k_norm.weight") + state_dict.pop(f"blocks.{i}.attn2.k_norm.bias") + + # out_proj -> to_out + state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"] + state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"] + state_dict.pop(f"blocks.{i}.attn2.out_proj.weight") + state_dict.pop(f"blocks.{i}.attn2.out_proj.bias") + + # switch norm 2 and norm 3 + norm2_weight = state_dict[f"blocks.{i}.norm2.weight"] + norm2_bias = state_dict[f"blocks.{i}.norm2.bias"] + state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"] + state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"] + state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight + state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias + + # norm1 -> norm1.norm + # default_modulation.1 -> norm1.linear + state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"] + state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"] + state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"] + state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"] + state_dict.pop(f"blocks.{i}.norm1.weight") + state_dict.pop(f"blocks.{i}.norm1.bias") + state_dict.pop(f"blocks.{i}.default_modulation.1.weight") + state_dict.pop(f"blocks.{i}.default_modulation.1.bias") + + # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2 + state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"] + state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"] + state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"] + state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"] + state_dict.pop(f"blocks.{i}.mlp.fc1.weight") + state_dict.pop(f"blocks.{i}.mlp.fc1.bias") + state_dict.pop(f"blocks.{i}.mlp.fc2.weight") + state_dict.pop(f"blocks.{i}.mlp.fc2.bias") + + # pooler -> time_extra_emb + state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"] + state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"] + state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"] + state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"] + state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"] + state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"] + state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"] + state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"] + state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"] + state_dict.pop("pooler.k_proj.weight") + state_dict.pop("pooler.k_proj.bias") + state_dict.pop("pooler.q_proj.weight") + state_dict.pop("pooler.q_proj.bias") + state_dict.pop("pooler.v_proj.weight") + state_dict.pop("pooler.v_proj.bias") + state_dict.pop("pooler.c_proj.weight") + state_dict.pop("pooler.c_proj.bias") + state_dict.pop("pooler.positional_embedding") + + # t_embedder -> time_embedding (`TimestepEmbedding`) + state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"] + state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"] + state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"] + state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"] + + state_dict.pop("t_embedder.mlp.0.bias") + state_dict.pop("t_embedder.mlp.0.weight") + state_dict.pop("t_embedder.mlp.2.bias") + state_dict.pop("t_embedder.mlp.2.weight") + + # x_embedder -> pos_embd (`PatchEmbed`) + state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"] + state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"] + state_dict.pop("x_embedder.proj.weight") + state_dict.pop("x_embedder.proj.bias") + + # mlp_t5 -> text_embedder + state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"] + state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"] + state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"] + state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"] + state_dict.pop("mlp_t5.0.bias") + state_dict.pop("mlp_t5.0.weight") + state_dict.pop("mlp_t5.2.bias") + state_dict.pop("mlp_t5.2.weight") + + # extra_embedder -> extra_embedder + state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"] + state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"] + state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"] + state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"] + state_dict.pop("extra_embedder.0.bias") + state_dict.pop("extra_embedder.0.weight") + state_dict.pop("extra_embedder.2.bias") + state_dict.pop("extra_embedder.2.weight") + + # model.final_adaLN_modulation.1 -> norm_out.linear + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"]) + state_dict["norm_out.linear.bias"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"]) + state_dict.pop("final_layer.adaLN_modulation.1.weight") + state_dict.pop("final_layer.adaLN_modulation.1.bias") + + # final_linear -> proj_out + state_dict["proj_out.weight"] = state_dict["final_layer.linear.weight"] + state_dict["proj_out.bias"] = state_dict["final_layer.linear.bias"] + state_dict.pop("final_layer.linear.weight") + state_dict.pop("final_layer.linear.bias") + + # style_embedder + if model_config["use_style_cond_and_image_meta_size"]: + print(state_dict["style_embedder.weight"]) + print(state_dict["style_embedder.weight"].shape) + state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1] + state_dict.pop("style_embedder.weight") + + model.load_state_dict(state_dict) + + from diffusers import HunyuanDiTPipeline + + if args.use_style_cond_and_image_meta_size: + pipe = HunyuanDiTPipeline.from_pretrained( + "Tencent-Hunyuan/HunyuanDiT-Diffusers", transformer=model, torch_dtype=torch.float32 + ) + else: + pipe = HunyuanDiTPipeline.from_pretrained( + "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", transformer=model, torch_dtype=torch.float32 + ) + pipe.to("cuda") + pipe.to(dtype=torch.float32) + + if args.save: + pipe.save_pretrained(args.output_checkpoint_path) + + # ### NOTE: HunyuanDiT supports both Chinese and English inputs + prompt = "一个宇航员在骑马" + # prompt = "An astronaut riding a horse" + generator = torch.Generator(device="cuda").manual_seed(0) + image = pipe( + height=1024, width=1024, prompt=prompt, generator=generator, num_inference_steps=25, guidance_scale=5.0 + ).images[0] + + image.save("img.png") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not." + ) + parser.add_argument( + "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model." + ) + parser.add_argument( + "--output_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the output converted diffusers pipeline.", + ) + parser.add_argument( + "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file" + ) + parser.add_argument( + "--use_style_cond_and_image_meta_size", + type=bool, + default=False, + help="version <= v1.1: True; version >= v1.2: False", + ) + + args = parser.parse_args() + main(args) diff --git a/diffusers/scripts/convert_i2vgen_to_diffusers.py b/diffusers/scripts/convert_i2vgen_to_diffusers.py new file mode 100755 index 0000000..b9e3ff2 --- /dev/null +++ b/diffusers/scripts/convert_i2vgen_to_diffusers.py @@ -0,0 +1,510 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the LDM checkpoints.""" + +import argparse + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers import DDIMScheduler, I2VGenXLPipeline, I2VGenXLUNet, StableDiffusionPipeline + + +CLIP_ID = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + weight = old_checkpoint[path["old"]] + names = ["proj_attn.weight"] + names_2 = ["proj_out.weight", "proj_in.weight"] + if any(k in new_path for k in names): + checkpoint[new_path] = weight[:, :, 0] + elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path: + checkpoint[new_path] = weight[:, :, 0] + else: + checkpoint[new_path] = weight + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + mapping.append({"old": old_item, "new": old_item}) + + return mapping + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + if "temopral_conv" not in old_item: + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + additional_embedding_substrings = [ + "local_image_concat", + "context_embedding", + "local_image_embedding", + "fps_embedding", + ] + for k in unet_state_dict: + if any(substring in k for substring in additional_embedding_substrings): + diffusers_key = k.replace("local_image_concat", "image_latents_proj_in").replace( + "local_image_embedding", "image_latents_context_embedding" + ) + new_checkpoint[diffusers_key] = unet_state_dict[k] + + # temporal encoder. + new_checkpoint["image_latents_temporal_encoder.norm1.weight"] = unet_state_dict[ + "local_temporal_encoder.layers.0.0.norm.weight" + ] + new_checkpoint["image_latents_temporal_encoder.norm1.bias"] = unet_state_dict[ + "local_temporal_encoder.layers.0.0.norm.bias" + ] + + # attention + qkv = unet_state_dict["local_temporal_encoder.layers.0.0.fn.to_qkv.weight"] + q, k, v = torch.chunk(qkv, 3, dim=0) + new_checkpoint["image_latents_temporal_encoder.attn1.to_q.weight"] = q + new_checkpoint["image_latents_temporal_encoder.attn1.to_k.weight"] = k + new_checkpoint["image_latents_temporal_encoder.attn1.to_v.weight"] = v + new_checkpoint["image_latents_temporal_encoder.attn1.to_out.0.weight"] = unet_state_dict[ + "local_temporal_encoder.layers.0.0.fn.to_out.0.weight" + ] + new_checkpoint["image_latents_temporal_encoder.attn1.to_out.0.bias"] = unet_state_dict[ + "local_temporal_encoder.layers.0.0.fn.to_out.0.bias" + ] + + # feedforward + new_checkpoint["image_latents_temporal_encoder.ff.net.0.proj.weight"] = unet_state_dict[ + "local_temporal_encoder.layers.0.1.net.0.0.weight" + ] + new_checkpoint["image_latents_temporal_encoder.ff.net.0.proj.bias"] = unet_state_dict[ + "local_temporal_encoder.layers.0.1.net.0.0.bias" + ] + new_checkpoint["image_latents_temporal_encoder.ff.net.2.weight"] = unet_state_dict[ + "local_temporal_encoder.layers.0.1.net.2.weight" + ] + new_checkpoint["image_latents_temporal_encoder.ff.net.2.bias"] = unet_state_dict[ + "local_temporal_encoder.layers.0.1.net.2.bias" + ] + + if "class_embed_type" in config: + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")] + paths = renew_attention_paths(first_temp_attention) + meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key] + + if f"input_blocks.{i}.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temporal_convs = [key for key in resnets if "temopral_conv" in key] + paths = renew_temp_conv_paths(temporal_convs) + meta_path = { + "old": f"input_blocks.{i}.0.temopral_conv", + "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(temp_attentions): + paths = renew_attention_paths(temp_attentions) + meta_path = { + "old": f"input_blocks.{i}.2", + "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key] + attentions = middle_blocks[1] + temp_attentions = middle_blocks[2] + resnet_1 = middle_blocks[3] + temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key] + + resnet_0_paths = renew_resnet_paths(resnet_0) + meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"} + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0) + meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"} + assign_to_checkpoint( + temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + resnet_1_paths = renew_resnet_paths(resnet_1) + meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"} + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1) + meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"} + assign_to_checkpoint( + temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temp_attentions_paths = renew_attention_paths(temp_attentions) + meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"} + assign_to_checkpoint( + temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temporal_convs = [key for key in resnets if "temopral_conv" in key] + paths = renew_temp_conv_paths(temporal_convs) + meta_path = { + "old": f"output_blocks.{i}.0.temopral_conv", + "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(temp_attentions): + paths = renew_attention_paths(temp_attentions) + meta_path = { + "old": f"output_blocks.{i}.2", + "new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + new_checkpoint[new_path] = unet_state_dict[old_path] + + temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l] + for path in temopral_conv_paths: + pruned_path = path.split("temopral_conv.")[-1] + old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path]) + new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path]) + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--unet_checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--push_to_hub", action="store_true") + args = parser.parse_args() + + # UNet + unet_checkpoint = torch.load(args.unet_checkpoint_path, map_location="cpu") + unet_checkpoint = unet_checkpoint["state_dict"] + unet = I2VGenXLUNet(sample_size=32) + + converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config) + + diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys()) + diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys()) + + assert len(diff_0) == len(diff_1) == 0, "Converted weights don't match" + + unet.load_state_dict(converted_ckpt, strict=True) + + # vae + temp_pipe = StableDiffusionPipeline.from_single_file( + "https://huggingface.co/ali-vilab/i2vgen-xl/blob/main/models/v2-1_512-ema-pruned.ckpt" + ) + vae = temp_pipe.vae + del temp_pipe + + # text encoder and tokenizer + text_encoder = CLIPTextModel.from_pretrained(CLIP_ID) + tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID) + + # image encoder and feature extractor + image_encoder = CLIPVisionModelWithProjection.from_pretrained(CLIP_ID) + feature_extractor = CLIPImageProcessor.from_pretrained(CLIP_ID) + + # scheduler + # https://github.com/ali-vilab/i2vgen-xl/blob/main/configs/i2vgen_xl_train.yaml + scheduler = DDIMScheduler( + beta_schedule="squaredcos_cap_v2", + rescale_betas_zero_snr=True, + set_alpha_to_one=True, + clip_sample=False, + steps_offset=1, + timestep_spacing="leading", + prediction_type="v_prediction", + ) + + # final + pipeline = I2VGenXLPipeline( + unet=unet, + vae=vae, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + ) + + pipeline.save_pretrained(args.dump_path, push_to_hub=args.push_to_hub) diff --git a/diffusers/scripts/convert_if.py b/diffusers/scripts/convert_if.py new file mode 100755 index 0000000..85c739c --- /dev/null +++ b/diffusers/scripts/convert_if.py @@ -0,0 +1,1250 @@ +import argparse +import inspect +import os + +import numpy as np +import torch +import yaml +from torch.nn import functional as F +from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer + +from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet2DConditionModel +from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", required=False, default=None, type=str) + + parser.add_argument("--dump_path_stage_2", required=False, default=None, type=str) + + parser.add_argument("--dump_path_stage_3", required=False, default=None, type=str) + + parser.add_argument("--unet_config", required=False, default=None, type=str, help="Path to unet config file") + + parser.add_argument( + "--unet_checkpoint_path", required=False, default=None, type=str, help="Path to unet checkpoint file" + ) + + parser.add_argument( + "--unet_checkpoint_path_stage_2", + required=False, + default=None, + type=str, + help="Path to stage 2 unet checkpoint file", + ) + + parser.add_argument( + "--unet_checkpoint_path_stage_3", + required=False, + default=None, + type=str, + help="Path to stage 3 unet checkpoint file", + ) + + parser.add_argument("--p_head_path", type=str, required=True) + + parser.add_argument("--w_head_path", type=str, required=True) + + args = parser.parse_args() + + return args + + +def main(args): + tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl") + text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl") + + feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") + safety_checker = convert_safety_checker(p_head_path=args.p_head_path, w_head_path=args.w_head_path) + + if args.unet_config is not None and args.unet_checkpoint_path is not None and args.dump_path is not None: + convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args) + + if args.unet_checkpoint_path_stage_2 is not None and args.dump_path_stage_2 is not None: + convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=2) + + if args.unet_checkpoint_path_stage_3 is not None and args.dump_path_stage_3 is not None: + convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=3) + + +def convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args): + unet = get_stage_1_unet(args.unet_config, args.unet_checkpoint_path) + + scheduler = DDPMScheduler( + variance_type="learned_range", + beta_schedule="squaredcos_cap_v2", + prediction_type="epsilon", + thresholding=True, + dynamic_thresholding_ratio=0.95, + sample_max_value=1.5, + ) + + pipe = IFPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=True, + ) + + pipe.save_pretrained(args.dump_path) + + +def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage): + if stage == 2: + unet_checkpoint_path = args.unet_checkpoint_path_stage_2 + sample_size = None + dump_path = args.dump_path_stage_2 + elif stage == 3: + unet_checkpoint_path = args.unet_checkpoint_path_stage_3 + sample_size = 1024 + dump_path = args.dump_path_stage_3 + else: + assert False + + unet = get_super_res_unet(unet_checkpoint_path, verify_param_count=False, sample_size=sample_size) + + image_noising_scheduler = DDPMScheduler( + beta_schedule="squaredcos_cap_v2", + ) + + scheduler = DDPMScheduler( + variance_type="learned_range", + beta_schedule="squaredcos_cap_v2", + prediction_type="epsilon", + thresholding=True, + dynamic_thresholding_ratio=0.95, + sample_max_value=1.0, + ) + + pipe = IFSuperResolutionPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=True, + ) + + pipe.save_pretrained(dump_path) + + +def get_stage_1_unet(unet_config, unet_checkpoint_path): + original_unet_config = yaml.safe_load(unet_config) + original_unet_config = original_unet_config["params"] + + unet_diffusers_config = create_unet_diffusers_config(original_unet_config) + + unet = UNet2DConditionModel(**unet_diffusers_config) + + device = "cuda" if torch.cuda.is_available() else "cpu" + unet_checkpoint = torch.load(unet_checkpoint_path, map_location=device) + + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path + ) + + unet.load_state_dict(converted_unet_checkpoint) + + return unet + + +def convert_safety_checker(p_head_path, w_head_path): + state_dict = {} + + # p head + + p_head = np.load(p_head_path) + + p_head_weights = p_head["weights"] + p_head_weights = torch.from_numpy(p_head_weights) + p_head_weights = p_head_weights.unsqueeze(0) + + p_head_biases = p_head["biases"] + p_head_biases = torch.from_numpy(p_head_biases) + p_head_biases = p_head_biases.unsqueeze(0) + + state_dict["p_head.weight"] = p_head_weights + state_dict["p_head.bias"] = p_head_biases + + # w head + + w_head = np.load(w_head_path) + + w_head_weights = w_head["weights"] + w_head_weights = torch.from_numpy(w_head_weights) + w_head_weights = w_head_weights.unsqueeze(0) + + w_head_biases = w_head["biases"] + w_head_biases = torch.from_numpy(w_head_biases) + w_head_biases = w_head_biases.unsqueeze(0) + + state_dict["w_head.weight"] = w_head_weights + state_dict["w_head.bias"] = w_head_biases + + # vision model + + vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + vision_model_state_dict = vision_model.state_dict() + + for key, value in vision_model_state_dict.items(): + key = f"vision_model.{key}" + state_dict[key] = value + + # full model + + config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14") + safety_checker = IFSafetyChecker(config) + + safety_checker.load_state_dict(state_dict) + + return safety_checker + + +def create_unet_diffusers_config(original_unet_config, class_embed_type=None): + attention_resolutions = parse_list(original_unet_config["attention_resolutions"]) + attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions] + + channel_mult = parse_list(original_unet_config["channel_mult"]) + block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult] + + down_block_types = [] + resolution = 1 + + for i in range(len(block_out_channels)): + if resolution in attention_resolutions: + block_type = "SimpleCrossAttnDownBlock2D" + elif original_unet_config["resblock_updown"]: + block_type = "ResnetDownsampleBlock2D" + else: + block_type = "DownBlock2D" + + down_block_types.append(block_type) + + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + if resolution in attention_resolutions: + block_type = "SimpleCrossAttnUpBlock2D" + elif original_unet_config["resblock_updown"]: + block_type = "ResnetUpsampleBlock2D" + else: + block_type = "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + head_dim = original_unet_config["num_head_channels"] + + use_linear_projection = ( + original_unet_config["use_linear_in_transformer"] + if "use_linear_in_transformer" in original_unet_config + else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim = [5, 10, 20, 20] + + projection_class_embeddings_input_dim = None + + if class_embed_type is None: + if "num_classes" in original_unet_config: + if original_unet_config["num_classes"] == "sequential": + class_embed_type = "projection" + assert "adm_in_channels" in original_unet_config + projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"] + else: + raise NotImplementedError( + f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}" + ) + + config = { + "sample_size": original_unet_config["image_size"], + "in_channels": original_unet_config["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": original_unet_config["num_res_blocks"], + "cross_attention_dim": original_unet_config["encoder_channels"], + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "out_channels": original_unet_config["out_channels"], + "up_block_types": tuple(up_block_types), + "upcast_attention": False, # TODO: guessing + "cross_attention_norm": "group_norm", + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "addition_embed_type": "text", + "act_fn": "gelu", + } + + if original_unet_config["use_scale_shift_norm"]: + config["resnet_time_scale_shift"] = "scale_shift" + + if "encoder_dim" in original_unet_config: + config["encoder_hid_dim"] = original_unet_config["encoder_dim"] + + return config + + +def convert_ldm_unet_checkpoint(unet_state_dict, config, path=None): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] in [None, "identity"]: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + + # TODO need better check than i in [4, 8, 12, 16] + block_type = config["down_block_types"][block_id] + if (block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D") and i in [ + 4, + 8, + 12, + 16, + ]: + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"} + else: + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + old_path = f"input_blocks.{i}.1" + new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + paths = renew_attention_paths(attentions) + meta_path = {"old": old_path, "new": new_path} + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + old_path = "middle_block.1" + new_path = "mid_block.attentions.0" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + # len(output_block_list) == 1 -> resnet + # len(output_block_list) == 2 -> resnet, attention + # len(output_block_list) == 3 -> resnet, attention, upscale resnet + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + old_path = f"output_blocks.{i}.1" + new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + paths = renew_attention_paths(attentions) + meta_path = { + "old": old_path, + "new": new_path, + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(output_block_list) == 3: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key] + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"output_blocks.{i}.2", "new": f"up_blocks.{block_id}.upsamplers.0"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if "encoder_proj.weight" in unet_state_dict: + new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict.pop("encoder_proj.weight") + new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict.pop("encoder_proj.bias") + + if "encoder_pooling.0.weight" in unet_state_dict: + new_checkpoint["add_embedding.norm1.weight"] = unet_state_dict.pop("encoder_pooling.0.weight") + new_checkpoint["add_embedding.norm1.bias"] = unet_state_dict.pop("encoder_pooling.0.bias") + + new_checkpoint["add_embedding.pool.positional_embedding"] = unet_state_dict.pop( + "encoder_pooling.1.positional_embedding" + ) + new_checkpoint["add_embedding.pool.k_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.k_proj.weight") + new_checkpoint["add_embedding.pool.k_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.k_proj.bias") + new_checkpoint["add_embedding.pool.q_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.q_proj.weight") + new_checkpoint["add_embedding.pool.q_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.q_proj.bias") + new_checkpoint["add_embedding.pool.v_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.v_proj.weight") + new_checkpoint["add_embedding.pool.v_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.v_proj.bias") + + new_checkpoint["add_embedding.proj.weight"] = unet_state_dict.pop("encoder_pooling.2.weight") + new_checkpoint["add_embedding.proj.bias"] = unet_state_dict.pop("encoder_pooling.2.bias") + + new_checkpoint["add_embedding.norm2.weight"] = unet_state_dict.pop("encoder_pooling.3.weight") + new_checkpoint["add_embedding.norm2.bias"] = unet_state_dict.pop("encoder_pooling.3.bias") + + return new_checkpoint + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + if "qkv" in new_item: + continue + + if "encoder_kv" in new_item: + continue + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = new_item.replace("norm_encoder.weight", "norm_cross.weight") + new_item = new_item.replace("norm_encoder.bias", "norm_cross.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_attention_to_checkpoint(new_checkpoint, unet_state_dict, old_path, new_path, config): + qkv_weight = unet_state_dict.pop(f"{old_path}.qkv.weight") + qkv_weight = qkv_weight[:, :, 0] + + qkv_bias = unet_state_dict.pop(f"{old_path}.qkv.bias") + + is_cross_attn_only = "only_cross_attention" in config and config["only_cross_attention"] + + split = 1 if is_cross_attn_only else 3 + + weights, bias = split_attentions( + weight=qkv_weight, + bias=qkv_bias, + split=split, + chunk_size=config["attention_head_dim"], + ) + + if is_cross_attn_only: + query_weight, q_bias = weights, bias + new_checkpoint[f"{new_path}.to_q.weight"] = query_weight[0] + new_checkpoint[f"{new_path}.to_q.bias"] = q_bias[0] + else: + [query_weight, key_weight, value_weight], [q_bias, k_bias, v_bias] = weights, bias + new_checkpoint[f"{new_path}.to_q.weight"] = query_weight + new_checkpoint[f"{new_path}.to_q.bias"] = q_bias + new_checkpoint[f"{new_path}.to_k.weight"] = key_weight + new_checkpoint[f"{new_path}.to_k.bias"] = k_bias + new_checkpoint[f"{new_path}.to_v.weight"] = value_weight + new_checkpoint[f"{new_path}.to_v.bias"] = v_bias + + encoder_kv_weight = unet_state_dict.pop(f"{old_path}.encoder_kv.weight") + encoder_kv_weight = encoder_kv_weight[:, :, 0] + + encoder_kv_bias = unet_state_dict.pop(f"{old_path}.encoder_kv.bias") + + [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions( + weight=encoder_kv_weight, + bias=encoder_kv_bias, + split=2, + chunk_size=config["attention_head_dim"], + ) + + new_checkpoint[f"{new_path}.add_k_proj.weight"] = encoder_k_weight + new_checkpoint[f"{new_path}.add_k_proj.bias"] = encoder_k_bias + new_checkpoint[f"{new_path}.add_v_proj.weight"] = encoder_v_weight + new_checkpoint[f"{new_path}.add_v_proj.bias"] = encoder_v_bias + + +def assign_to_checkpoint(paths, checkpoint, old_checkpoint, additional_replacements=None, config=None): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + for path in paths: + new_path = path["new"] + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path or "to_out.0.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?) +def split_attentions(*, weight, bias, split, chunk_size): + weights = [None] * split + biases = [None] * split + + weights_biases_idx = 0 + + for starting_row_index in range(0, weight.shape[0], chunk_size): + row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size) + + weight_rows = weight[row_indices, :] + bias_rows = bias[row_indices] + + if weights[weights_biases_idx] is None: + weights[weights_biases_idx] = weight_rows + biases[weights_biases_idx] = bias_rows + else: + assert weights[weights_biases_idx] is not None + weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows]) + biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows]) + + weights_biases_idx = (weights_biases_idx + 1) % split + + return weights, biases + + +def parse_list(value): + if isinstance(value, str): + value = value.split(",") + value = [int(v) for v in value] + elif isinstance(value, list): + pass + else: + raise ValueError(f"Can't parse list for type: {type(value)}") + + return value + + +# below is copy and pasted from original convert_if_stage_2.py script + + +def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None): + orig_path = unet_checkpoint_path + + original_unet_config = yaml.safe_load(os.path.join(orig_path, "config.yml")) + original_unet_config = original_unet_config["params"] + + unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config) + unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * int( + original_unet_config["channel_mult"].split(",")[-1] + ) + if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]: + unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"] + unet_diffusers_config["class_embed_type"] = "timestep" + unet_diffusers_config["addition_embed_type"] = "text" + + unet_diffusers_config["time_embedding_act_fn"] = "gelu" + unet_diffusers_config["resnet_skip_time_act"] = True + unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071 + unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071 + unet_diffusers_config["only_cross_attention"] = ( + bool(original_unet_config["disable_self_attentions"]) + if ( + "disable_self_attentions" in original_unet_config + and isinstance(original_unet_config["disable_self_attentions"], int) + ) + else True + ) + + if sample_size is None: + unet_diffusers_config["sample_size"] = original_unet_config["image_size"] + else: + # The second upscaler unet's sample size is incorrectly specified + # in the config and is instead hardcoded in source + unet_diffusers_config["sample_size"] = sample_size + + unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu") + + if verify_param_count: + # check that architecture matches - is a bit slow + verify_param_count(orig_path, unet_diffusers_config) + + converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint( + unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path + ) + converted_keys = converted_unet_checkpoint.keys() + + model = UNet2DConditionModel(**unet_diffusers_config) + expected_weights = model.state_dict().keys() + + diff_c_e = set(converted_keys) - set(expected_weights) + diff_e_c = set(expected_weights) - set(converted_keys) + + assert len(diff_e_c) == 0, f"Expected, but not converted: {diff_e_c}" + assert len(diff_c_e) == 0, f"Converted, but not expected: {diff_c_e}" + + model.load_state_dict(converted_unet_checkpoint) + + return model + + +def superres_create_unet_diffusers_config(original_unet_config): + attention_resolutions = parse_list(original_unet_config["attention_resolutions"]) + attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions] + + channel_mult = parse_list(original_unet_config["channel_mult"]) + block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult] + + down_block_types = [] + resolution = 1 + + for i in range(len(block_out_channels)): + if resolution in attention_resolutions: + block_type = "SimpleCrossAttnDownBlock2D" + elif original_unet_config["resblock_updown"]: + block_type = "ResnetDownsampleBlock2D" + else: + block_type = "DownBlock2D" + + down_block_types.append(block_type) + + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + if resolution in attention_resolutions: + block_type = "SimpleCrossAttnUpBlock2D" + elif original_unet_config["resblock_updown"]: + block_type = "ResnetUpsampleBlock2D" + else: + block_type = "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + head_dim = original_unet_config["num_head_channels"] + use_linear_projection = ( + original_unet_config["use_linear_in_transformer"] + if "use_linear_in_transformer" in original_unet_config + else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim = [5, 10, 20, 20] + + class_embed_type = None + projection_class_embeddings_input_dim = None + + if "num_classes" in original_unet_config: + if original_unet_config["num_classes"] == "sequential": + class_embed_type = "projection" + assert "adm_in_channels" in original_unet_config + projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"] + else: + raise NotImplementedError( + f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}" + ) + + config = { + "in_channels": original_unet_config["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": tuple(original_unet_config["num_res_blocks"]), + "cross_attention_dim": original_unet_config["encoder_channels"], + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "out_channels": original_unet_config["out_channels"], + "up_block_types": tuple(up_block_types), + "upcast_attention": False, # TODO: guessing + "cross_attention_norm": "group_norm", + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "act_fn": "gelu", + } + + if original_unet_config["use_scale_shift_norm"]: + config["resnet_time_scale_shift"] = "scale_shift" + + return config + + +def superres_convert_ldm_unet_checkpoint(unet_state_dict, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["aug_proj.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["aug_proj.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["aug_proj.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["aug_proj.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if "encoder_proj.weight" in unet_state_dict: + new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict["encoder_proj.weight"] + new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict["encoder_proj.bias"] + + if "encoder_pooling.0.weight" in unet_state_dict: + mapping = { + "encoder_pooling.0": "add_embedding.norm1", + "encoder_pooling.1": "add_embedding.pool", + "encoder_pooling.2": "add_embedding.proj", + "encoder_pooling.3": "add_embedding.norm2", + } + for key in unet_state_dict.keys(): + if key.startswith("encoder_pooling"): + prefix = key[: len("encoder_pooling.0")] + new_key = key.replace(prefix, mapping[prefix]) + new_checkpoint[new_key] = unet_state_dict[key] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] + for layer_id in range(num_output_blocks) + } + if not isinstance(config["layers_per_block"], int): + layers_per_block_list = [e + 1 for e in config["layers_per_block"]] + layers_per_block_cumsum = list(np.cumsum(layers_per_block_list)) + downsampler_ids = layers_per_block_cumsum + else: + # TODO need better check than i in [4, 8, 12, 16] + downsampler_ids = [4, 8, 12, 16] + + for i in range(1, num_input_blocks): + if isinstance(config["layers_per_block"], int): + layers_per_block = config["layers_per_block"] + block_id = (i - 1) // (layers_per_block + 1) + layer_in_block_id = (i - 1) % (layers_per_block + 1) + else: + block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if (i - 1) < n) + passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0 + layer_in_block_id = (i - 1) - passed_blocks + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + + block_type = config["down_block_types"][block_id] + if ( + block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D" + ) and i in downsampler_ids: + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"} + else: + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + old_path = f"input_blocks.{i}.1" + new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + paths = renew_attention_paths(attentions) + meta_path = {"old": old_path, "new": new_path} + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + old_path = "middle_block.1" + new_path = "mid_block.attentions.0" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + if not isinstance(config["layers_per_block"], int): + layers_per_block_list = list(reversed([e + 1 for e in config["layers_per_block"]])) + layers_per_block_cumsum = list(np.cumsum(layers_per_block_list)) + + for i in range(num_output_blocks): + if isinstance(config["layers_per_block"], int): + layers_per_block = config["layers_per_block"] + block_id = i // (layers_per_block + 1) + layer_in_block_id = i % (layers_per_block + 1) + else: + block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if i < n) + passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0 + layer_in_block_id = i - passed_blocks + + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + # len(output_block_list) == 1 -> resnet + # len(output_block_list) == 2 -> resnet, attention or resnet, upscale resnet + # len(output_block_list) == 3 -> resnet, attention, upscale resnet + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + + has_attention = True + if len(output_block_list) == 2 and any("in_layers" in k for k in output_block_list["1"]): + has_attention = False + + maybe_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # this layer was no attention + has_attention = False + maybe_attentions = [] + + if has_attention: + old_path = f"output_blocks.{i}.1" + new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + paths = renew_attention_paths(maybe_attentions) + meta_path = { + "old": old_path, + "new": new_path, + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(output_block_list) == 3 or (not has_attention and len(maybe_attentions) > 0): + layer_id = len(output_block_list) - 1 + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.{layer_id}" in key] + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"output_blocks.{i}.{layer_id}", "new": f"up_blocks.{block_id}.upsamplers.0"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +def verify_param_count(orig_path, unet_diffusers_config): + if "-II-" in orig_path: + from deepfloyd_if.modules import IFStageII + + if_II = IFStageII(device="cpu", dir_or_name=orig_path) + elif "-III-" in orig_path: + from deepfloyd_if.modules import IFStageIII + + if_II = IFStageIII(device="cpu", dir_or_name=orig_path) + else: + assert f"Weird name. Should have -II- or -III- in path: {orig_path}" + + unet = UNet2DConditionModel(**unet_diffusers_config) + + # in params + assert_param_count(unet.time_embedding, if_II.model.time_embed) + assert_param_count(unet.conv_in, if_II.model.input_blocks[:1]) + + # downblocks + assert_param_count(unet.down_blocks[0], if_II.model.input_blocks[1:4]) + assert_param_count(unet.down_blocks[1], if_II.model.input_blocks[4:7]) + assert_param_count(unet.down_blocks[2], if_II.model.input_blocks[7:11]) + + if "-II-" in orig_path: + assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:17]) + assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[17:]) + if "-III-" in orig_path: + assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:15]) + assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[15:20]) + assert_param_count(unet.down_blocks[5], if_II.model.input_blocks[20:]) + + # mid block + assert_param_count(unet.mid_block, if_II.model.middle_block) + + # up block + if "-II-" in orig_path: + assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:6]) + assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[6:12]) + assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[12:16]) + assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[16:19]) + assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[19:]) + if "-III-" in orig_path: + assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:5]) + assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[5:10]) + assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[10:14]) + assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[14:18]) + assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[18:21]) + assert_param_count(unet.up_blocks[5], if_II.model.output_blocks[21:24]) + + # out params + assert_param_count(unet.conv_norm_out, if_II.model.out[0]) + assert_param_count(unet.conv_out, if_II.model.out[2]) + + # make sure all model architecture has same param count + assert_param_count(unet, if_II.model) + + +def assert_param_count(model_1, model_2): + count_1 = sum(p.numel() for p in model_1.parameters()) + count_2 = sum(p.numel() for p in model_2.parameters()) + assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}" + + +def superres_check_against_original(dump_path, unet_checkpoint_path): + model_path = dump_path + model = UNet2DConditionModel.from_pretrained(model_path) + model.to("cuda") + orig_path = unet_checkpoint_path + + if "-II-" in orig_path: + from deepfloyd_if.modules import IFStageII + + if_II_model = IFStageII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model + elif "-III-" in orig_path: + from deepfloyd_if.modules import IFStageIII + + if_II_model = IFStageIII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model + + batch_size = 1 + channels = model.config.in_channels // 2 + height = model.config.sample_size + width = model.config.sample_size + height = 1024 + width = 1024 + + torch.manual_seed(0) + + latents = torch.randn((batch_size, channels, height, width), device=model.device) + image_small = torch.randn((batch_size, channels, height // 4, width // 4), device=model.device) + + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + latent_model_input = torch.cat([latents, image_upscaled], dim=1).to(model.dtype) + t = torch.tensor([5], device=model.device).to(model.dtype) + + seq_len = 64 + encoder_hidden_states = torch.randn((batch_size, seq_len, model.config.encoder_hid_dim), device=model.device).to( + model.dtype + ) + + fake_class_labels = torch.tensor([t], device=model.device).to(model.dtype) + + with torch.no_grad(): + out = if_II_model(latent_model_input, t, aug_steps=fake_class_labels, text_emb=encoder_hidden_states) + + if_II_model.to("cpu") + del if_II_model + import gc + + torch.cuda.empty_cache() + gc.collect() + print(50 * "=") + + with torch.no_grad(): + noise_pred = model( + sample=latent_model_input, + encoder_hidden_states=encoder_hidden_states, + class_labels=fake_class_labels, + timestep=t, + ).sample + + print("Out shape", noise_pred.shape) + print("Diff", (out - noise_pred).abs().sum()) + + +if __name__ == "__main__": + main(parse_args()) diff --git a/diffusers/scripts/convert_k_upscaler_to_diffusers.py b/diffusers/scripts/convert_k_upscaler_to_diffusers.py new file mode 100755 index 0000000..62abedd --- /dev/null +++ b/diffusers/scripts/convert_k_upscaler_to_diffusers.py @@ -0,0 +1,297 @@ +import argparse + +import huggingface_hub +import k_diffusion as K +import torch + +from diffusers import UNet2DConditionModel + + +UPSCALER_REPO = "pcuenq/k-upscaler" + + +def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"], + f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"], + f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"], + } + ) + + return rv + + +def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0) + bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0) + rv = { + # norm + f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"], + f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"], + # to_q + f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1), + f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q, + # to_k + f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1), + f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k, + # to_v + f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1), + f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v, + # to_out + f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"] + .squeeze(-1) + .squeeze(-1), + f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"], + } + + return rv + + +def cross_attn_to_diffusers_checkpoint( + checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix +): + weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0) + bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0) + + rv = { + # norm2 (ada groupnorm) + f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[ + f"{attention_prefix}.norm_dec.mapper.weight" + ], + f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[ + f"{attention_prefix}.norm_dec.mapper.bias" + ], + # layernorm on encoder_hidden_state + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[ + f"{attention_prefix}.norm_enc.weight" + ], + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[ + f"{attention_prefix}.norm_enc.bias" + ], + # to_q + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[ + f"{attention_prefix}.q_proj.weight" + ] + .squeeze(-1) + .squeeze(-1), + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[ + f"{attention_prefix}.q_proj.bias" + ], + # to_k + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1), + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k, + # to_v + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1), + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v, + # to_out + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[ + f"{attention_prefix}.out_proj.weight" + ] + .squeeze(-1) + .squeeze(-1), + f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[ + f"{attention_prefix}.out_proj.bias" + ], + } + + return rv + + +def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type): + block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks" + block_prefix = f"{block_prefix}.{block_idx}" + + diffusers_checkpoint = {} + + if not hasattr(block, "attentions"): + n = 1 # resnet only + elif not block.attentions[0].add_self_attention: + n = 2 # resnet -> cross-attention + else: + n = 3 # resnet -> self-attention -> cross-attention) + + for resnet_idx, resnet in enumerate(block.resnets): + # diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" + diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}" + idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1 + resnet_prefix = f"{block_prefix}.{idx}" if block_type == "up" else f"{block_prefix}.{idx}" + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + if hasattr(block, "attentions"): + for attention_idx, attention in enumerate(block.attentions): + diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}" + idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2 + self_attention_prefix = f"{block_prefix}.{idx}" + cross_attention_prefix = f"{block_prefix}.{idx }" + cross_attention_index = 1 if not attention.add_self_attention else 2 + idx = ( + n * attention_idx + cross_attention_index + if block_type == "up" + else n * attention_idx + cross_attention_index + 1 + ) + cross_attention_prefix = f"{block_prefix}.{idx }" + + diffusers_checkpoint.update( + cross_attn_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + diffusers_attention_index=2, + attention_prefix=cross_attention_prefix, + ) + ) + + if attention.add_self_attention is True: + diffusers_checkpoint.update( + self_attn_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=self_attention_prefix, + ) + ) + + return diffusers_checkpoint + + +def unet_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # pre-processing + diffusers_checkpoint.update( + { + "conv_in.weight": checkpoint["inner_model.proj_in.weight"], + "conv_in.bias": checkpoint["inner_model.proj_in.bias"], + } + ) + + # timestep and class embedding + diffusers_checkpoint.update( + { + "time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1), + "time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"], + "time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"], + "time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"], + "time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"], + "time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"], + } + ) + + # down_blocks + for down_block_idx, down_block in enumerate(model.down_blocks): + diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down")) + + # up_blocks + for up_block_idx, up_block in enumerate(model.up_blocks): + diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up")) + + # post-processing + diffusers_checkpoint.update( + { + "conv_out.weight": checkpoint["inner_model.proj_out.weight"], + "conv_out.bias": checkpoint["inner_model.proj_out.bias"], + } + ) + + return diffusers_checkpoint + + +def unet_model_from_original_config(original_config): + in_channels = original_config["input_channels"] + original_config["unet_cond_dim"] + out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0) + + block_out_channels = original_config["channels"] + + assert ( + len(set(original_config["depths"])) == 1 + ), "UNet2DConditionModel currently do not support blocks with different number of layers" + layers_per_block = original_config["depths"][0] + + class_labels_dim = original_config["mapping_cond_dim"] + cross_attention_dim = original_config["cross_cond_dim"] + + attn1_types = [] + attn2_types = [] + for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]): + if s: + a1 = "self" + a2 = "cross" if c else None + elif c: + a1 = "cross" + a2 = None + else: + a1 = None + a2 = None + attn1_types.append(a1) + attn2_types.append(a2) + + unet = UNet2DConditionModel( + in_channels=in_channels, + out_channels=out_channels, + down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"), + mid_block_type=None, + up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"), + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn="gelu", + norm_num_groups=None, + cross_attention_dim=cross_attention_dim, + attention_head_dim=64, + time_cond_proj_dim=class_labels_dim, + resnet_time_scale_shift="scale_shift", + time_embedding_type="fourier", + timestep_post_act="gelu", + conv_in_kernel=1, + conv_out_kernel=1, + ) + + return unet + + +def main(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json") + orig_weights_path = huggingface_hub.hf_hub_download( + UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth" + ) + print(f"loading original model configuration from {orig_config_path}") + print(f"loading original model checkpoint from {orig_weights_path}") + + print("converting to diffusers unet") + orig_config = K.config.load_config(open(orig_config_path))["model"] + model = unet_model_from_original_config(orig_config) + + orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"] + converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint) + + model.load_state_dict(converted_checkpoint, strict=True) + model.save_pretrained(args.dump_path) + print(f"saving converted unet model in {args.dump_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + args = parser.parse_args() + + main(args) diff --git a/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py b/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py new file mode 100755 index 0000000..5135eae --- /dev/null +++ b/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py @@ -0,0 +1,1159 @@ +import argparse +import tempfile + +import torch +from accelerate import load_checkpoint_and_dispatch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel +from diffusers.models.transformers.prior_transformer import PriorTransformer +from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel +from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler + + +r""" +Example - From the diffusers root directory: + +Download weights: +```sh +$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt +$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt +$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt +$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th +``` + +Convert the model: +```sh +$ python scripts/convert_kakao_brain_unclip_to_diffusers.py \ + --decoder_checkpoint_path ./decoder-ckpt-step\=01000000-of-01000000.ckpt \ + --super_res_unet_checkpoint_path ./improved-sr-ckpt-step\=1.2M.ckpt \ + --prior_checkpoint_path ./prior-ckpt-step\=01000000-of-01000000.ckpt \ + --clip_stat_path ./ViT-L-14_stats.th \ + --dump_path +``` +""" + + +# prior + +PRIOR_ORIGINAL_PREFIX = "model" + +# Uses default arguments +PRIOR_CONFIG = {} + + +def prior_model_from_original_config(): + model = PriorTransformer(**PRIOR_CONFIG) + + return model + + +def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint): + diffusers_checkpoint = {} + + # .time_embed.0 -> .time_embedding.linear_1 + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"], + "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"], + } + ) + + # .clip_img_proj -> .proj_in + diffusers_checkpoint.update( + { + "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"], + "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"], + } + ) + + # .text_emb_proj -> .embedding_proj + diffusers_checkpoint.update( + { + "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"], + "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"], + } + ) + + # .text_enc_proj -> .encoder_hidden_states_proj + diffusers_checkpoint.update( + { + "encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"], + "encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"], + } + ) + + # .positional_embedding -> .positional_embedding + diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]}) + + # .prd_emb -> .prd_embedding + diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]}) + + # .time_embed.2 -> .time_embedding.linear_2 + diffusers_checkpoint.update( + { + "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"], + "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"], + } + ) + + # .resblocks. -> .transformer_blocks. + for idx in range(len(model.transformer_blocks)): + diffusers_transformer_prefix = f"transformer_blocks.{idx}" + original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}" + + # .attn -> .attn1 + diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1" + original_attention_prefix = f"{original_transformer_prefix}.attn" + diffusers_checkpoint.update( + prior_attention_to_diffusers( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + original_attention_prefix=original_attention_prefix, + attention_head_dim=model.attention_head_dim, + ) + ) + + # .mlp -> .ff + diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff" + original_ff_prefix = f"{original_transformer_prefix}.mlp" + diffusers_checkpoint.update( + prior_ff_to_diffusers( + checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix + ) + ) + + # .ln_1 -> .norm1 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[ + f"{original_transformer_prefix}.ln_1.weight" + ], + f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"], + } + ) + + # .ln_2 -> .norm3 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[ + f"{original_transformer_prefix}.ln_2.weight" + ], + f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"], + } + ) + + # .final_ln -> .norm_out + diffusers_checkpoint.update( + { + "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"], + "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"], + } + ) + + # .out_proj -> .proj_to_clip_embeddings + diffusers_checkpoint.update( + { + "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"], + "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"], + } + ) + + # clip stats + clip_mean, clip_std = clip_stats_checkpoint + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std}) + + return diffusers_checkpoint + + +def prior_attention_to_diffusers( + checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim +): + diffusers_checkpoint = {} + + # .c_qkv -> .{to_q, to_k, to_v} + [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions( + weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"], + bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"], + split=3, + chunk_size=attention_head_dim, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_q.weight": q_weight, + f"{diffusers_attention_prefix}.to_q.bias": q_bias, + f"{diffusers_attention_prefix}.to_k.weight": k_weight, + f"{diffusers_attention_prefix}.to_k.bias": k_bias, + f"{diffusers_attention_prefix}.to_v.weight": v_weight, + f"{diffusers_attention_prefix}.to_v.bias": v_bias, + } + ) + + # .c_proj -> .to_out.0 + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"], + } + ) + + return diffusers_checkpoint + + +def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix): + diffusers_checkpoint = { + # .c_fc -> .net.0.proj + f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"], + f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"], + # .c_proj -> .net.2 + f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"], + f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"], + } + + return diffusers_checkpoint + + +# done prior + + +# decoder + +DECODER_ORIGINAL_PREFIX = "model" + +# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can +# update then. +DECODER_CONFIG = { + "sample_size": 64, + "layers_per_block": 3, + "down_block_types": ( + "ResnetDownsampleBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + ), + "up_block_types": ( + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "ResnetUpsampleBlock2D", + ), + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "block_out_channels": (320, 640, 960, 1280), + "in_channels": 3, + "out_channels": 6, + "cross_attention_dim": 1536, + "class_embed_type": "identity", + "attention_head_dim": 64, + "resnet_time_scale_shift": "scale_shift", +} + + +def decoder_model_from_original_config(): + model = UNet2DConditionModel(**DECODER_CONFIG) + + return model + + +def decoder_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + original_unet_prefix = DECODER_ORIGINAL_PREFIX + num_head_channels = DECODER_CONFIG["attention_head_dim"] + + diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix)) + + # .input_blocks -> .down_blocks + + original_down_block_idx = 1 + + for diffusers_down_block_idx in range(len(model.down_blocks)): + checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_down_block_idx=diffusers_down_block_idx, + original_down_block_idx=original_down_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=num_head_channels, + ) + + original_down_block_idx += num_original_down_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .input_blocks -> .down_blocks + + diffusers_checkpoint.update( + unet_midblock_to_diffusers_checkpoint( + model, + checkpoint, + original_unet_prefix=original_unet_prefix, + num_head_channels=num_head_channels, + ) + ) + + # .output_blocks -> .up_blocks + + original_up_block_idx = 0 + + for diffusers_up_block_idx in range(len(model.up_blocks)): + checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_up_block_idx=diffusers_up_block_idx, + original_up_block_idx=original_up_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=num_head_channels, + ) + + original_up_block_idx += num_original_up_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .output_blocks -> .up_blocks + + diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix)) + + return diffusers_checkpoint + + +# done decoder + +# text proj + + +def text_proj_from_original_config(): + # From the conditional unet constructor where the dimension of the projected time embeddings is + # constructed + time_embed_dim = DECODER_CONFIG["block_out_channels"][0] * 4 + + cross_attention_dim = DECODER_CONFIG["cross_attention_dim"] + + model = UnCLIPTextProjModel(time_embed_dim=time_embed_dim, cross_attention_dim=cross_attention_dim) + + return model + + +# Note that the input checkpoint is the original decoder checkpoint +def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint): + diffusers_checkpoint = { + # .text_seq_proj.0 -> .encoder_hidden_states_proj + "encoder_hidden_states_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.weight"], + "encoder_hidden_states_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.bias"], + # .text_seq_proj.1 -> .text_encoder_hidden_states_norm + "text_encoder_hidden_states_norm.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.weight"], + "text_encoder_hidden_states_norm.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.bias"], + # .clip_tok_proj -> .clip_extra_context_tokens_proj + "clip_extra_context_tokens_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.weight"], + "clip_extra_context_tokens_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.bias"], + # .text_feat_proj -> .embedding_proj + "embedding_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.weight"], + "embedding_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.bias"], + # .cf_param -> .learned_classifier_free_guidance_embeddings + "learned_classifier_free_guidance_embeddings": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.cf_param"], + # .clip_emb -> .clip_image_embeddings_project_to_time_embeddings + "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint[ + f"{DECODER_ORIGINAL_PREFIX}.clip_emb.weight" + ], + "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint[ + f"{DECODER_ORIGINAL_PREFIX}.clip_emb.bias" + ], + } + + return diffusers_checkpoint + + +# done text proj + +# super res unet first steps + +SUPER_RES_UNET_FIRST_STEPS_PREFIX = "model_first_steps" + +SUPER_RES_UNET_FIRST_STEPS_CONFIG = { + "sample_size": 256, + "layers_per_block": 3, + "down_block_types": ( + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ), + "up_block_types": ( + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ), + "block_out_channels": (320, 640, 960, 1280), + "in_channels": 6, + "out_channels": 3, + "add_attention": False, +} + + +def super_res_unet_first_steps_model_from_original_config(): + model = UNet2DModel(**SUPER_RES_UNET_FIRST_STEPS_CONFIG) + + return model + + +def super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + original_unet_prefix = SUPER_RES_UNET_FIRST_STEPS_PREFIX + + diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix)) + + # .input_blocks -> .down_blocks + + original_down_block_idx = 1 + + for diffusers_down_block_idx in range(len(model.down_blocks)): + checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_down_block_idx=diffusers_down_block_idx, + original_down_block_idx=original_down_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + + original_down_block_idx += num_original_down_blocks + + diffusers_checkpoint.update(checkpoint_update) + + diffusers_checkpoint.update( + unet_midblock_to_diffusers_checkpoint( + model, + checkpoint, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + ) + + # .output_blocks -> .up_blocks + + original_up_block_idx = 0 + + for diffusers_up_block_idx in range(len(model.up_blocks)): + checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_up_block_idx=diffusers_up_block_idx, + original_up_block_idx=original_up_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + + original_up_block_idx += num_original_up_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .output_blocks -> .up_blocks + + diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix)) + + return diffusers_checkpoint + + +# done super res unet first steps + +# super res unet last step + +SUPER_RES_UNET_LAST_STEP_PREFIX = "model_last_step" + +SUPER_RES_UNET_LAST_STEP_CONFIG = { + "sample_size": 256, + "layers_per_block": 3, + "down_block_types": ( + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ), + "up_block_types": ( + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ), + "block_out_channels": (320, 640, 960, 1280), + "in_channels": 6, + "out_channels": 3, + "add_attention": False, +} + + +def super_res_unet_last_step_model_from_original_config(): + model = UNet2DModel(**SUPER_RES_UNET_LAST_STEP_CONFIG) + + return model + + +def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + original_unet_prefix = SUPER_RES_UNET_LAST_STEP_PREFIX + + diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix)) + + # .input_blocks -> .down_blocks + + original_down_block_idx = 1 + + for diffusers_down_block_idx in range(len(model.down_blocks)): + checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_down_block_idx=diffusers_down_block_idx, + original_down_block_idx=original_down_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + + original_down_block_idx += num_original_down_blocks + + diffusers_checkpoint.update(checkpoint_update) + + diffusers_checkpoint.update( + unet_midblock_to_diffusers_checkpoint( + model, + checkpoint, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + ) + + # .output_blocks -> .up_blocks + + original_up_block_idx = 0 + + for diffusers_up_block_idx in range(len(model.up_blocks)): + checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_up_block_idx=diffusers_up_block_idx, + original_up_block_idx=original_up_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + + original_up_block_idx += num_original_up_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .output_blocks -> .up_blocks + + diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix)) + + return diffusers_checkpoint + + +# done super res unet last step + + +# unet utils + + +# .time_embed -> .time_embedding +def unet_time_embeddings(checkpoint, original_unet_prefix): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint[f"{original_unet_prefix}.time_embed.0.weight"], + "time_embedding.linear_1.bias": checkpoint[f"{original_unet_prefix}.time_embed.0.bias"], + "time_embedding.linear_2.weight": checkpoint[f"{original_unet_prefix}.time_embed.2.weight"], + "time_embedding.linear_2.bias": checkpoint[f"{original_unet_prefix}.time_embed.2.bias"], + } + ) + + return diffusers_checkpoint + + +# .input_blocks.0 -> .conv_in +def unet_conv_in(checkpoint, original_unet_prefix): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_in.weight": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.weight"], + "conv_in.bias": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.bias"], + } + ) + + return diffusers_checkpoint + + +# .out.0 -> .conv_norm_out +def unet_conv_norm_out(checkpoint, original_unet_prefix): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_norm_out.weight": checkpoint[f"{original_unet_prefix}.out.0.weight"], + "conv_norm_out.bias": checkpoint[f"{original_unet_prefix}.out.0.bias"], + } + ) + + return diffusers_checkpoint + + +# .out.2 -> .conv_out +def unet_conv_out(checkpoint, original_unet_prefix): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_out.weight": checkpoint[f"{original_unet_prefix}.out.2.weight"], + "conv_out.bias": checkpoint[f"{original_unet_prefix}.out.2.bias"], + } + ) + + return diffusers_checkpoint + + +# .input_blocks -> .down_blocks +def unet_downblock_to_diffusers_checkpoint( + model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, original_unet_prefix, num_head_channels +): + diffusers_checkpoint = {} + + diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets" + original_down_block_prefix = f"{original_unet_prefix}.input_blocks" + + down_block = model.down_blocks[diffusers_down_block_idx] + + num_resnets = len(down_block.resnets) + + if down_block.downsamplers is None: + downsampler = False + else: + assert len(down_block.downsamplers) == 1 + downsampler = True + # The downsample block is also a resnet + num_resnets += 1 + + for resnet_idx_inc in range(num_resnets): + full_resnet_prefix = f"{original_down_block_prefix}.{original_down_block_idx + resnet_idx_inc}.0" + + if downsampler and resnet_idx_inc == num_resnets - 1: + # this is a downsample block + full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0" + else: + # this is a regular resnet block + full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}" + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix + ) + ) + + if hasattr(down_block, "attentions"): + num_attentions = len(down_block.attentions) + diffusers_attention_prefix = f"down_blocks.{diffusers_down_block_idx}.attentions" + + for attention_idx_inc in range(num_attentions): + full_attention_prefix = f"{original_down_block_prefix}.{original_down_block_idx + attention_idx_inc}.1" + full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}" + + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + attention_prefix=full_attention_prefix, + diffusers_attention_prefix=full_diffusers_attention_prefix, + num_head_channels=num_head_channels, + ) + ) + + num_original_down_blocks = num_resnets + + return diffusers_checkpoint, num_original_down_blocks + + +# .middle_block -> .mid_block +def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, original_unet_prefix, num_head_channels): + diffusers_checkpoint = {} + + # block 0 + + original_block_idx = 0 + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, + diffusers_resnet_prefix="mid_block.resnets.0", + resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}", + ) + ) + + original_block_idx += 1 + + # optional block 1 + + if hasattr(model.mid_block, "attentions") and model.mid_block.attentions[0] is not None: + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix="mid_block.attentions.0", + attention_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}", + num_head_channels=num_head_channels, + ) + ) + original_block_idx += 1 + + # block 1 or block 2 + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, + diffusers_resnet_prefix="mid_block.resnets.1", + resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}", + ) + ) + + return diffusers_checkpoint + + +# .output_blocks -> .up_blocks +def unet_upblock_to_diffusers_checkpoint( + model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, original_unet_prefix, num_head_channels +): + diffusers_checkpoint = {} + + diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.resnets" + original_up_block_prefix = f"{original_unet_prefix}.output_blocks" + + up_block = model.up_blocks[diffusers_up_block_idx] + + num_resnets = len(up_block.resnets) + + if up_block.upsamplers is None: + upsampler = False + else: + assert len(up_block.upsamplers) == 1 + upsampler = True + # The upsample block is also a resnet + num_resnets += 1 + + has_attentions = hasattr(up_block, "attentions") + + for resnet_idx_inc in range(num_resnets): + if upsampler and resnet_idx_inc == num_resnets - 1: + # this is an upsample block + if has_attentions: + # There is a middle attention block that we skip + original_resnet_block_idx = 2 + else: + original_resnet_block_idx = 1 + + # we add the `minus 1` because the last two resnets are stuck together in the same output block + full_resnet_prefix = ( + f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc - 1}.{original_resnet_block_idx}" + ) + + full_diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.upsamplers.0" + else: + # this is a regular resnet block + full_resnet_prefix = f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc}.0" + full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}" + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix + ) + ) + + if has_attentions: + num_attentions = len(up_block.attentions) + diffusers_attention_prefix = f"up_blocks.{diffusers_up_block_idx}.attentions" + + for attention_idx_inc in range(num_attentions): + full_attention_prefix = f"{original_up_block_prefix}.{original_up_block_idx + attention_idx_inc}.1" + full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}" + + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + attention_prefix=full_attention_prefix, + diffusers_attention_prefix=full_diffusers_attention_prefix, + num_head_channels=num_head_channels, + ) + ) + + num_original_down_blocks = num_resnets - 1 if upsampler else num_resnets + + return diffusers_checkpoint, num_original_down_blocks + + +def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + diffusers_checkpoint = { + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"], + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"], + f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"], + f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"], + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"], + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"], + } + + skip_connection_prefix = f"{resnet_prefix}.skip_connection" + + if f"{skip_connection_prefix}.weight" in checkpoint: + diffusers_checkpoint.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{skip_connection_prefix}.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{skip_connection_prefix}.bias"], + } + ) + + return diffusers_checkpoint + + +def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels): + diffusers_checkpoint = {} + + # .norm -> .group_norm + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + } + ) + + # .qkv -> .{query, key, value} + [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions( + weight=checkpoint[f"{attention_prefix}.qkv.weight"][:, :, 0], + bias=checkpoint[f"{attention_prefix}.qkv.bias"], + split=3, + chunk_size=num_head_channels, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_q.weight": q_weight, + f"{diffusers_attention_prefix}.to_q.bias": q_bias, + f"{diffusers_attention_prefix}.to_k.weight": k_weight, + f"{diffusers_attention_prefix}.to_k.bias": k_bias, + f"{diffusers_attention_prefix}.to_v.weight": v_weight, + f"{diffusers_attention_prefix}.to_v.bias": v_bias, + } + ) + + # .encoder_kv -> .{context_key, context_value} + [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions( + weight=checkpoint[f"{attention_prefix}.encoder_kv.weight"][:, :, 0], + bias=checkpoint[f"{attention_prefix}.encoder_kv.bias"], + split=2, + chunk_size=num_head_channels, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.add_k_proj.weight": encoder_k_weight, + f"{diffusers_attention_prefix}.add_k_proj.bias": encoder_k_bias, + f"{diffusers_attention_prefix}.add_v_proj.weight": encoder_v_weight, + f"{diffusers_attention_prefix}.add_v_proj.bias": encoder_v_bias, + } + ) + + # .proj_out (1d conv) -> .proj_attn (linear) + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0 + ], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + ) + + return diffusers_checkpoint + + +# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?) +def split_attentions(*, weight, bias, split, chunk_size): + weights = [None] * split + biases = [None] * split + + weights_biases_idx = 0 + + for starting_row_index in range(0, weight.shape[0], chunk_size): + row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size) + + weight_rows = weight[row_indices, :] + bias_rows = bias[row_indices] + + if weights[weights_biases_idx] is None: + assert weights[weights_biases_idx] is None + weights[weights_biases_idx] = weight_rows + biases[weights_biases_idx] = bias_rows + else: + assert weights[weights_biases_idx] is not None + weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows]) + biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows]) + + weights_biases_idx = (weights_biases_idx + 1) % split + + return weights, biases + + +# done unet utils + + +# Driver functions + + +def text_encoder(): + print("loading CLIP text encoder") + + clip_name = "openai/clip-vit-large-patch14" + + # sets pad_value to 0 + pad_token = "!" + + tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto") + + assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0 + + text_encoder_model = CLIPTextModelWithProjection.from_pretrained( + clip_name, + # `CLIPTextModel` does not support device_map="auto" + # device_map="auto" + ) + + print("done loading CLIP text encoder") + + return text_encoder_model, tokenizer_model + + +def prior(*, args, checkpoint_map_location): + print("loading prior") + + prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location) + prior_checkpoint = prior_checkpoint["state_dict"] + + clip_stats_checkpoint = torch.load(args.clip_stat_path, map_location=checkpoint_map_location) + + prior_model = prior_model_from_original_config() + + prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint( + prior_model, prior_checkpoint, clip_stats_checkpoint + ) + + del prior_checkpoint + del clip_stats_checkpoint + + load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True) + + print("done loading prior") + + return prior_model + + +def decoder(*, args, checkpoint_map_location): + print("loading decoder") + + decoder_checkpoint = torch.load(args.decoder_checkpoint_path, map_location=checkpoint_map_location) + decoder_checkpoint = decoder_checkpoint["state_dict"] + + decoder_model = decoder_model_from_original_config() + + decoder_diffusers_checkpoint = decoder_original_checkpoint_to_diffusers_checkpoint( + decoder_model, decoder_checkpoint + ) + + # text proj interlude + + # The original decoder implementation includes a set of parameters that are used + # for creating the `encoder_hidden_states` which are what the U-net is conditioned + # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull + # the parameters into the UnCLIPTextProjModel class + text_proj_model = text_proj_from_original_config() + + text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(decoder_checkpoint) + + load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True) + + # done text proj interlude + + del decoder_checkpoint + + load_checkpoint_to_model(decoder_diffusers_checkpoint, decoder_model, strict=True) + + print("done loading decoder") + + return decoder_model, text_proj_model + + +def super_res_unet(*, args, checkpoint_map_location): + print("loading super resolution unet") + + super_res_checkpoint = torch.load(args.super_res_unet_checkpoint_path, map_location=checkpoint_map_location) + super_res_checkpoint = super_res_checkpoint["state_dict"] + + # model_first_steps + + super_res_first_model = super_res_unet_first_steps_model_from_original_config() + + super_res_first_steps_checkpoint = super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint( + super_res_first_model, super_res_checkpoint + ) + + # model_last_step + super_res_last_model = super_res_unet_last_step_model_from_original_config() + + super_res_last_step_checkpoint = super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint( + super_res_last_model, super_res_checkpoint + ) + + del super_res_checkpoint + + load_checkpoint_to_model(super_res_first_steps_checkpoint, super_res_first_model, strict=True) + + load_checkpoint_to_model(super_res_last_step_checkpoint, super_res_last_model, strict=True) + + print("done loading super resolution unet") + + return super_res_first_model, super_res_last_model + + +def load_checkpoint_to_model(checkpoint, model, strict=False): + with tempfile.NamedTemporaryFile() as file: + torch.save(checkpoint, file.name) + del checkpoint + if strict: + model.load_state_dict(torch.load(file.name), strict=True) + else: + load_checkpoint_and_dispatch(model, file.name, device_map="auto") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--prior_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the prior checkpoint to convert.", + ) + + parser.add_argument( + "--decoder_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the decoder checkpoint to convert.", + ) + + parser.add_argument( + "--super_res_unet_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the super resolution checkpoint to convert.", + ) + + parser.add_argument( + "--clip_stat_path", default=None, type=str, required=True, help="Path to the clip stats checkpoint to convert." + ) + + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + parser.add_argument( + "--debug", + default=None, + type=str, + required=False, + help="Only run a specific stage of the convert script. Used for debugging", + ) + + args = parser.parse_args() + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + if args.debug is not None: + print(f"debug: only executing {args.debug}") + + if args.debug is None: + text_encoder_model, tokenizer_model = text_encoder() + + prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location) + + decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location) + + super_res_first_model, super_res_last_model = super_res_unet( + args=args, checkpoint_map_location=checkpoint_map_location + ) + + prior_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="sample", + num_train_timesteps=1000, + clip_sample_range=5.0, + ) + + decoder_scheduler = UnCLIPScheduler( + variance_type="learned_range", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + super_res_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + print(f"saving Kakao Brain unCLIP to {args.dump_path}") + + pipe = UnCLIPPipeline( + prior=prior_model, + decoder=decoder_model, + text_proj=text_proj_model, + tokenizer=tokenizer_model, + text_encoder=text_encoder_model, + super_res_first=super_res_first_model, + super_res_last=super_res_last_model, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + pipe.save_pretrained(args.dump_path) + + print("done writing Kakao Brain unCLIP") + elif args.debug == "text_encoder": + text_encoder_model, tokenizer_model = text_encoder() + elif args.debug == "prior": + prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location) + elif args.debug == "decoder": + decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location) + elif args.debug == "super_res_unet": + super_res_first_model, super_res_last_model = super_res_unet( + args=args, checkpoint_map_location=checkpoint_map_location + ) + else: + raise ValueError(f"unknown debug value : {args.debug}") diff --git a/diffusers/scripts/convert_kandinsky3_unet.py b/diffusers/scripts/convert_kandinsky3_unet.py new file mode 100755 index 0000000..4fe8c54 --- /dev/null +++ b/diffusers/scripts/convert_kandinsky3_unet.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +import argparse +import fnmatch + +from safetensors.torch import load_file + +from diffusers import Kandinsky3UNet + + +MAPPING = { + "to_time_embed.1": "time_embedding.linear_1", + "to_time_embed.3": "time_embedding.linear_2", + "in_layer": "conv_in", + "out_layer.0": "conv_norm_out", + "out_layer.2": "conv_out", + "down_samples": "down_blocks", + "up_samples": "up_blocks", + "projection_lin": "encoder_hid_proj.projection_linear", + "projection_ln": "encoder_hid_proj.projection_norm", + "feature_pooling": "add_time_condition", + "to_query": "to_q", + "to_key": "to_k", + "to_value": "to_v", + "output_layer": "to_out.0", + "self_attention_block": "attentions.0", +} + +DYNAMIC_MAP = { + "resnet_attn_blocks.*.0": "resnets_in.*", + "resnet_attn_blocks.*.1": ("attentions.*", 1), + "resnet_attn_blocks.*.2": "resnets_out.*", +} +# MAPPING = {} + + +def convert_state_dict(unet_state_dict): + """ + Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model. + Args: + unet_model (torch.nn.Module): The original U-Net model. + unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with. + + Returns: + OrderedDict: The converted state dictionary. + """ + # Example of renaming logic (this will vary based on your model's architecture) + converted_state_dict = {} + for key in unet_state_dict: + new_key = key + for pattern, new_pattern in MAPPING.items(): + new_key = new_key.replace(pattern, new_pattern) + + for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items(): + has_matched = False + if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched: + star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1]) + + if isinstance(dyn_new_pattern, tuple): + new_star = star + dyn_new_pattern[-1] + dyn_new_pattern = dyn_new_pattern[0] + else: + new_star = star + + pattern = dyn_pattern.replace("*", str(star)) + new_pattern = dyn_new_pattern.replace("*", str(new_star)) + + new_key = new_key.replace(pattern, new_pattern) + has_matched = True + + converted_state_dict[new_key] = unet_state_dict[key] + + return converted_state_dict + + +def main(model_path, output_path): + # Load your original U-Net model + unet_state_dict = load_file(model_path) + + # Initialize your Kandinsky3UNet model + config = {} + + # Convert the state dict + converted_state_dict = convert_state_dict(unet_state_dict) + + unet = Kandinsky3UNet(config) + unet.load_state_dict(converted_state_dict) + + unet.save_pretrained(output_path) + print(f"Converted model saved to {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format") + parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model") + parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model") + + args = parser.parse_args() + main(args.model_path, args.output_path) diff --git a/diffusers/scripts/convert_kandinsky_to_diffusers.py b/diffusers/scripts/convert_kandinsky_to_diffusers.py new file mode 100755 index 0000000..8d3f7b6 --- /dev/null +++ b/diffusers/scripts/convert_kandinsky_to_diffusers.py @@ -0,0 +1,1411 @@ +import argparse +import os +import tempfile + +import torch +from accelerate import load_checkpoint_and_dispatch + +from diffusers import UNet2DConditionModel +from diffusers.models.transformers.prior_transformer import PriorTransformer +from diffusers.models.vq_model import VQModel + + +""" +Example - From the diffusers root directory: + +Download weights: +```sh +$ wget https://huggingface.co/ai-forever/Kandinsky_2.1/blob/main/prior_fp16.ckpt +``` + +Convert the model: +```sh +python scripts/convert_kandinsky_to_diffusers.py \ + --prior_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/prior_fp16.ckpt \ + --clip_stat_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/ViT-L-14_stats.th \ + --text2img_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/decoder_fp16.ckpt \ + --inpaint_text2img_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/inpainting_fp16.ckpt \ + --movq_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/movq_final.ckpt \ + --dump_path /home/yiyi_huggingface_co/dump \ + --debug decoder +``` +""" + + +# prior + +PRIOR_ORIGINAL_PREFIX = "model" + +# Uses default arguments +PRIOR_CONFIG = {} + + +def prior_model_from_original_config(): + model = PriorTransformer(**PRIOR_CONFIG) + + return model + + +def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint): + diffusers_checkpoint = {} + + # .time_embed.0 -> .time_embedding.linear_1 + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"], + "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"], + } + ) + + # .clip_img_proj -> .proj_in + diffusers_checkpoint.update( + { + "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"], + "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"], + } + ) + + # .text_emb_proj -> .embedding_proj + diffusers_checkpoint.update( + { + "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"], + "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"], + } + ) + + # .text_enc_proj -> .encoder_hidden_states_proj + diffusers_checkpoint.update( + { + "encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"], + "encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"], + } + ) + + # .positional_embedding -> .positional_embedding + diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]}) + + # .prd_emb -> .prd_embedding + diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]}) + + # .time_embed.2 -> .time_embedding.linear_2 + diffusers_checkpoint.update( + { + "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"], + "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"], + } + ) + + # .resblocks. -> .transformer_blocks. + for idx in range(len(model.transformer_blocks)): + diffusers_transformer_prefix = f"transformer_blocks.{idx}" + original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}" + + # .attn -> .attn1 + diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1" + original_attention_prefix = f"{original_transformer_prefix}.attn" + diffusers_checkpoint.update( + prior_attention_to_diffusers( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + original_attention_prefix=original_attention_prefix, + attention_head_dim=model.attention_head_dim, + ) + ) + + # .mlp -> .ff + diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff" + original_ff_prefix = f"{original_transformer_prefix}.mlp" + diffusers_checkpoint.update( + prior_ff_to_diffusers( + checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix + ) + ) + + # .ln_1 -> .norm1 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[ + f"{original_transformer_prefix}.ln_1.weight" + ], + f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"], + } + ) + + # .ln_2 -> .norm3 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[ + f"{original_transformer_prefix}.ln_2.weight" + ], + f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"], + } + ) + + # .final_ln -> .norm_out + diffusers_checkpoint.update( + { + "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"], + "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"], + } + ) + + # .out_proj -> .proj_to_clip_embeddings + diffusers_checkpoint.update( + { + "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"], + "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"], + } + ) + + # clip stats + clip_mean, clip_std = clip_stats_checkpoint + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std}) + + return diffusers_checkpoint + + +def prior_attention_to_diffusers( + checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim +): + diffusers_checkpoint = {} + + # .c_qkv -> .{to_q, to_k, to_v} + [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions( + weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"], + bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"], + split=3, + chunk_size=attention_head_dim, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_q.weight": q_weight, + f"{diffusers_attention_prefix}.to_q.bias": q_bias, + f"{diffusers_attention_prefix}.to_k.weight": k_weight, + f"{diffusers_attention_prefix}.to_k.bias": k_bias, + f"{diffusers_attention_prefix}.to_v.weight": v_weight, + f"{diffusers_attention_prefix}.to_v.bias": v_bias, + } + ) + + # .c_proj -> .to_out.0 + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"], + } + ) + + return diffusers_checkpoint + + +def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix): + diffusers_checkpoint = { + # .c_fc -> .net.0.proj + f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"], + f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"], + # .c_proj -> .net.2 + f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"], + f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"], + } + + return diffusers_checkpoint + + +# done prior + +# unet + +# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can +# update then. + +UNET_CONFIG = { + "act_fn": "silu", + "addition_embed_type": "text_image", + "addition_embed_type_num_heads": 64, + "attention_head_dim": 64, + "block_out_channels": [384, 768, 1152, 1536], + "center_input_sample": False, + "class_embed_type": None, + "class_embeddings_concat": False, + "conv_in_kernel": 3, + "conv_out_kernel": 3, + "cross_attention_dim": 768, + "cross_attention_norm": None, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + ], + "downsample_padding": 1, + "dual_cross_attention": False, + "encoder_hid_dim": 1024, + "encoder_hid_dim_type": "text_image_proj", + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 3, + "mid_block_only_cross_attention": None, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": None, + "only_cross_attention": False, + "out_channels": 8, + "projection_class_embeddings_input_dim": None, + "resnet_out_scale_factor": 1.0, + "resnet_skip_time_act": False, + "resnet_time_scale_shift": "scale_shift", + "sample_size": 64, + "time_cond_proj_dim": None, + "time_embedding_act_fn": None, + "time_embedding_dim": None, + "time_embedding_type": "positional", + "timestep_post_act": None, + "up_block_types": [ + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "ResnetUpsampleBlock2D", + ], + "upcast_attention": False, + "use_linear_projection": False, +} + + +def unet_model_from_original_config(): + model = UNet2DConditionModel(**UNET_CONFIG) + + return model + + +def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + num_head_channels = UNET_CONFIG["attention_head_dim"] + + diffusers_checkpoint.update(unet_time_embeddings(checkpoint)) + diffusers_checkpoint.update(unet_conv_in(checkpoint)) + diffusers_checkpoint.update(unet_add_embedding(checkpoint)) + diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint)) + + # .input_blocks -> .down_blocks + + original_down_block_idx = 1 + + for diffusers_down_block_idx in range(len(model.down_blocks)): + checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_down_block_idx=diffusers_down_block_idx, + original_down_block_idx=original_down_block_idx, + num_head_channels=num_head_channels, + ) + + original_down_block_idx += num_original_down_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .input_blocks -> .down_blocks + + diffusers_checkpoint.update( + unet_midblock_to_diffusers_checkpoint( + model, + checkpoint, + num_head_channels=num_head_channels, + ) + ) + + # .output_blocks -> .up_blocks + + original_up_block_idx = 0 + + for diffusers_up_block_idx in range(len(model.up_blocks)): + checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_up_block_idx=diffusers_up_block_idx, + original_up_block_idx=original_up_block_idx, + num_head_channels=num_head_channels, + ) + + original_up_block_idx += num_original_up_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .output_blocks -> .up_blocks + + diffusers_checkpoint.update(unet_conv_norm_out(checkpoint)) + diffusers_checkpoint.update(unet_conv_out(checkpoint)) + + return diffusers_checkpoint + + +# done unet + +# inpaint unet + +# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can +# update then. + +INPAINT_UNET_CONFIG = { + "act_fn": "silu", + "addition_embed_type": "text_image", + "addition_embed_type_num_heads": 64, + "attention_head_dim": 64, + "block_out_channels": [384, 768, 1152, 1536], + "center_input_sample": False, + "class_embed_type": None, + "class_embeddings_concat": None, + "conv_in_kernel": 3, + "conv_out_kernel": 3, + "cross_attention_dim": 768, + "cross_attention_norm": None, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + ], + "downsample_padding": 1, + "dual_cross_attention": False, + "encoder_hid_dim": 1024, + "encoder_hid_dim_type": "text_image_proj", + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 9, + "layers_per_block": 3, + "mid_block_only_cross_attention": None, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": None, + "only_cross_attention": False, + "out_channels": 8, + "projection_class_embeddings_input_dim": None, + "resnet_out_scale_factor": 1.0, + "resnet_skip_time_act": False, + "resnet_time_scale_shift": "scale_shift", + "sample_size": 64, + "time_cond_proj_dim": None, + "time_embedding_act_fn": None, + "time_embedding_dim": None, + "time_embedding_type": "positional", + "timestep_post_act": None, + "up_block_types": [ + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "ResnetUpsampleBlock2D", + ], + "upcast_attention": False, + "use_linear_projection": False, +} + + +def inpaint_unet_model_from_original_config(): + model = UNet2DConditionModel(**INPAINT_UNET_CONFIG) + + return model + + +def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + num_head_channels = INPAINT_UNET_CONFIG["attention_head_dim"] + + diffusers_checkpoint.update(unet_time_embeddings(checkpoint)) + diffusers_checkpoint.update(unet_conv_in(checkpoint)) + diffusers_checkpoint.update(unet_add_embedding(checkpoint)) + diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint)) + + # .input_blocks -> .down_blocks + + original_down_block_idx = 1 + + for diffusers_down_block_idx in range(len(model.down_blocks)): + checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_down_block_idx=diffusers_down_block_idx, + original_down_block_idx=original_down_block_idx, + num_head_channels=num_head_channels, + ) + + original_down_block_idx += num_original_down_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .input_blocks -> .down_blocks + + diffusers_checkpoint.update( + unet_midblock_to_diffusers_checkpoint( + model, + checkpoint, + num_head_channels=num_head_channels, + ) + ) + + # .output_blocks -> .up_blocks + + original_up_block_idx = 0 + + for diffusers_up_block_idx in range(len(model.up_blocks)): + checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_up_block_idx=diffusers_up_block_idx, + original_up_block_idx=original_up_block_idx, + num_head_channels=num_head_channels, + ) + + original_up_block_idx += num_original_up_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .output_blocks -> .up_blocks + + diffusers_checkpoint.update(unet_conv_norm_out(checkpoint)) + diffusers_checkpoint.update(unet_conv_out(checkpoint)) + + return diffusers_checkpoint + + +# done inpaint unet + + +# unet utils + + +# .time_embed -> .time_embedding +def unet_time_embeddings(checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint["time_embed.0.weight"], + "time_embedding.linear_1.bias": checkpoint["time_embed.0.bias"], + "time_embedding.linear_2.weight": checkpoint["time_embed.2.weight"], + "time_embedding.linear_2.bias": checkpoint["time_embed.2.bias"], + } + ) + + return diffusers_checkpoint + + +# .input_blocks.0 -> .conv_in +def unet_conv_in(checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_in.weight": checkpoint["input_blocks.0.0.weight"], + "conv_in.bias": checkpoint["input_blocks.0.0.bias"], + } + ) + + return diffusers_checkpoint + + +def unet_add_embedding(checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"], + "add_embedding.text_norm.bias": checkpoint["ln_model_n.bias"], + "add_embedding.text_proj.weight": checkpoint["proj_n.weight"], + "add_embedding.text_proj.bias": checkpoint["proj_n.bias"], + "add_embedding.image_proj.weight": checkpoint["img_layer.weight"], + "add_embedding.image_proj.bias": checkpoint["img_layer.bias"], + } + ) + + return diffusers_checkpoint + + +def unet_encoder_hid_proj(checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"], + "encoder_hid_proj.image_embeds.bias": checkpoint["clip_to_seq.bias"], + "encoder_hid_proj.text_proj.weight": checkpoint["to_model_dim_n.weight"], + "encoder_hid_proj.text_proj.bias": checkpoint["to_model_dim_n.bias"], + } + ) + + return diffusers_checkpoint + + +# .out.0 -> .conv_norm_out +def unet_conv_norm_out(checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_norm_out.weight": checkpoint["out.0.weight"], + "conv_norm_out.bias": checkpoint["out.0.bias"], + } + ) + + return diffusers_checkpoint + + +# .out.2 -> .conv_out +def unet_conv_out(checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_out.weight": checkpoint["out.2.weight"], + "conv_out.bias": checkpoint["out.2.bias"], + } + ) + + return diffusers_checkpoint + + +# .input_blocks -> .down_blocks +def unet_downblock_to_diffusers_checkpoint( + model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, num_head_channels +): + diffusers_checkpoint = {} + + diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets" + original_down_block_prefix = "input_blocks" + + down_block = model.down_blocks[diffusers_down_block_idx] + + num_resnets = len(down_block.resnets) + + if down_block.downsamplers is None: + downsampler = False + else: + assert len(down_block.downsamplers) == 1 + downsampler = True + # The downsample block is also a resnet + num_resnets += 1 + + for resnet_idx_inc in range(num_resnets): + full_resnet_prefix = f"{original_down_block_prefix}.{original_down_block_idx + resnet_idx_inc}.0" + + if downsampler and resnet_idx_inc == num_resnets - 1: + # this is a downsample block + full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0" + else: + # this is a regular resnet block + full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}" + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix + ) + ) + + if hasattr(down_block, "attentions"): + num_attentions = len(down_block.attentions) + diffusers_attention_prefix = f"down_blocks.{diffusers_down_block_idx}.attentions" + + for attention_idx_inc in range(num_attentions): + full_attention_prefix = f"{original_down_block_prefix}.{original_down_block_idx + attention_idx_inc}.1" + full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}" + + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + attention_prefix=full_attention_prefix, + diffusers_attention_prefix=full_diffusers_attention_prefix, + num_head_channels=num_head_channels, + ) + ) + + num_original_down_blocks = num_resnets + + return diffusers_checkpoint, num_original_down_blocks + + +# .middle_block -> .mid_block +def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, num_head_channels): + diffusers_checkpoint = {} + + # block 0 + + original_block_idx = 0 + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, + diffusers_resnet_prefix="mid_block.resnets.0", + resnet_prefix=f"middle_block.{original_block_idx}", + ) + ) + + original_block_idx += 1 + + # optional block 1 + + if hasattr(model.mid_block, "attentions") and model.mid_block.attentions[0] is not None: + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix="mid_block.attentions.0", + attention_prefix=f"middle_block.{original_block_idx}", + num_head_channels=num_head_channels, + ) + ) + original_block_idx += 1 + + # block 1 or block 2 + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, + diffusers_resnet_prefix="mid_block.resnets.1", + resnet_prefix=f"middle_block.{original_block_idx}", + ) + ) + + return diffusers_checkpoint + + +# .output_blocks -> .up_blocks +def unet_upblock_to_diffusers_checkpoint( + model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, num_head_channels +): + diffusers_checkpoint = {} + + diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.resnets" + original_up_block_prefix = "output_blocks" + + up_block = model.up_blocks[diffusers_up_block_idx] + + num_resnets = len(up_block.resnets) + + if up_block.upsamplers is None: + upsampler = False + else: + assert len(up_block.upsamplers) == 1 + upsampler = True + # The upsample block is also a resnet + num_resnets += 1 + + has_attentions = hasattr(up_block, "attentions") + + for resnet_idx_inc in range(num_resnets): + if upsampler and resnet_idx_inc == num_resnets - 1: + # this is an upsample block + if has_attentions: + # There is a middle attention block that we skip + original_resnet_block_idx = 2 + else: + original_resnet_block_idx = 1 + + # we add the `minus 1` because the last two resnets are stuck together in the same output block + full_resnet_prefix = ( + f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc - 1}.{original_resnet_block_idx}" + ) + + full_diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.upsamplers.0" + else: + # this is a regular resnet block + full_resnet_prefix = f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc}.0" + full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}" + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix + ) + ) + + if has_attentions: + num_attentions = len(up_block.attentions) + diffusers_attention_prefix = f"up_blocks.{diffusers_up_block_idx}.attentions" + + for attention_idx_inc in range(num_attentions): + full_attention_prefix = f"{original_up_block_prefix}.{original_up_block_idx + attention_idx_inc}.1" + full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}" + + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + attention_prefix=full_attention_prefix, + diffusers_attention_prefix=full_diffusers_attention_prefix, + num_head_channels=num_head_channels, + ) + ) + + num_original_down_blocks = num_resnets - 1 if upsampler else num_resnets + + return diffusers_checkpoint, num_original_down_blocks + + +def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + diffusers_checkpoint = { + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"], + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"], + f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"], + f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"], + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"], + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"], + } + + skip_connection_prefix = f"{resnet_prefix}.skip_connection" + + if f"{skip_connection_prefix}.weight" in checkpoint: + diffusers_checkpoint.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{skip_connection_prefix}.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{skip_connection_prefix}.bias"], + } + ) + + return diffusers_checkpoint + + +def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels): + diffusers_checkpoint = {} + + # .norm -> .group_norm + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + } + ) + + # .qkv -> .{query, key, value} + [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions( + weight=checkpoint[f"{attention_prefix}.qkv.weight"][:, :, 0], + bias=checkpoint[f"{attention_prefix}.qkv.bias"], + split=3, + chunk_size=num_head_channels, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_q.weight": q_weight, + f"{diffusers_attention_prefix}.to_q.bias": q_bias, + f"{diffusers_attention_prefix}.to_k.weight": k_weight, + f"{diffusers_attention_prefix}.to_k.bias": k_bias, + f"{diffusers_attention_prefix}.to_v.weight": v_weight, + f"{diffusers_attention_prefix}.to_v.bias": v_bias, + } + ) + + # .encoder_kv -> .{context_key, context_value} + [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions( + weight=checkpoint[f"{attention_prefix}.encoder_kv.weight"][:, :, 0], + bias=checkpoint[f"{attention_prefix}.encoder_kv.bias"], + split=2, + chunk_size=num_head_channels, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.add_k_proj.weight": encoder_k_weight, + f"{diffusers_attention_prefix}.add_k_proj.bias": encoder_k_bias, + f"{diffusers_attention_prefix}.add_v_proj.weight": encoder_v_weight, + f"{diffusers_attention_prefix}.add_v_proj.bias": encoder_v_bias, + } + ) + + # .proj_out (1d conv) -> .proj_attn (linear) + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0 + ], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + ) + + return diffusers_checkpoint + + +# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?) +def split_attentions(*, weight, bias, split, chunk_size): + weights = [None] * split + biases = [None] * split + + weights_biases_idx = 0 + + for starting_row_index in range(0, weight.shape[0], chunk_size): + row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size) + + weight_rows = weight[row_indices, :] + bias_rows = bias[row_indices] + + if weights[weights_biases_idx] is None: + assert weights[weights_biases_idx] is None + weights[weights_biases_idx] = weight_rows + biases[weights_biases_idx] = bias_rows + else: + assert weights[weights_biases_idx] is not None + weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows]) + biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows]) + + weights_biases_idx = (weights_biases_idx + 1) % split + + return weights, biases + + +# done unet utils + + +def prior(*, args, checkpoint_map_location): + print("loading prior") + + prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location) + + clip_stats_checkpoint = torch.load(args.clip_stat_path, map_location=checkpoint_map_location) + + prior_model = prior_model_from_original_config() + + prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint( + prior_model, prior_checkpoint, clip_stats_checkpoint + ) + + del prior_checkpoint + del clip_stats_checkpoint + + load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True) + + print("done loading prior") + + return prior_model + + +def text2img(*, args, checkpoint_map_location): + print("loading text2img") + + text2img_checkpoint = torch.load(args.text2img_checkpoint_path, map_location=checkpoint_map_location) + + unet_model = unet_model_from_original_config() + + unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint) + + del text2img_checkpoint + + load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True) + + print("done loading text2img") + + return unet_model + + +def inpaint_text2img(*, args, checkpoint_map_location): + print("loading inpaint text2img") + + inpaint_text2img_checkpoint = torch.load( + args.inpaint_text2img_checkpoint_path, map_location=checkpoint_map_location + ) + + inpaint_unet_model = inpaint_unet_model_from_original_config() + + inpaint_unet_diffusers_checkpoint = inpaint_unet_original_checkpoint_to_diffusers_checkpoint( + inpaint_unet_model, inpaint_text2img_checkpoint + ) + + del inpaint_text2img_checkpoint + + load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True) + + print("done loading inpaint text2img") + + return inpaint_unet_model + + +# movq + +MOVQ_CONFIG = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 4, + "down_block_types": ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"), + "up_block_types": ("AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), + "num_vq_embeddings": 16384, + "block_out_channels": (128, 256, 256, 512), + "vq_embed_dim": 4, + "layers_per_block": 2, + "norm_type": "spatial", +} + + +def movq_model_from_original_config(): + movq = VQModel(**MOVQ_CONFIG) + return movq + + +def movq_encoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv_in + diffusers_checkpoint.update( + { + "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"], + "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"], + } + ) + + # down_blocks + for down_block_idx, down_block in enumerate(model.encoder.down_blocks): + diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}" + down_block_prefix = f"encoder.down.{down_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(down_block.resnets): + diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # downsample + + # do not include the downsample when on the last down block + # There is no downsample on the last down block + if down_block_idx != len(model.encoder.down_blocks) - 1: + # There's a single downsample in the original checkpoint but a list of downsamples + # in the diffusers model. + diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv" + downsample_prefix = f"{down_block_prefix}.downsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(down_block, "attentions"): + for attention_idx, _ in enumerate(down_block.attentions): + diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{down_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder + diffusers_attention_prefix = "encoder.mid_block.attentions.0" + attention_prefix = "encoder.mid.attn_1" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion encoder + resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"], + "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"], + # conv_out + "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"], + "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def movq_decoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv in + diffusers_checkpoint.update( + { + "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"], + "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"], + } + ) + + # up_blocks + + for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks): + # up_blocks are stored in reverse order in the VQ-diffusion checkpoint + orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx + + diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}" + up_block_prefix = f"decoder.up.{orig_up_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(up_block.resnets): + diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint_spatial_norm( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # upsample + + # there is no up sample on the last up block + if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1: + # There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples + # in the diffusers model. + diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv" + downsample_prefix = f"{up_block_prefix}.upsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(up_block, "attentions"): + for attention_idx, _ in enumerate(up_block.attentions): + diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{up_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint_spatial_norm( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder + diffusers_attention_prefix = "decoder.mid_block.attentions.0" + attention_prefix = "decoder.mid.attn_1" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint_spatial_norm( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion decoder + resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint_spatial_norm( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "decoder.conv_norm_out.norm_layer.weight": checkpoint["decoder.norm_out.norm_layer.weight"], + "decoder.conv_norm_out.norm_layer.bias": checkpoint["decoder.norm_out.norm_layer.bias"], + "decoder.conv_norm_out.conv_y.weight": checkpoint["decoder.norm_out.conv_y.weight"], + "decoder.conv_norm_out.conv_y.bias": checkpoint["decoder.norm_out.conv_y.bias"], + "decoder.conv_norm_out.conv_b.weight": checkpoint["decoder.norm_out.conv_b.weight"], + "decoder.conv_norm_out.conv_b.bias": checkpoint["decoder.norm_out.conv_b.bias"], + # conv_out + "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"], + "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def movq_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + + +def movq_resnet_to_diffusers_checkpoint_spatial_norm(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm1.norm_layer.weight"], + f"{diffusers_resnet_prefix}.norm1.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm1.norm_layer.bias"], + f"{diffusers_resnet_prefix}.norm1.conv_y.weight": checkpoint[f"{resnet_prefix}.norm1.conv_y.weight"], + f"{diffusers_resnet_prefix}.norm1.conv_y.bias": checkpoint[f"{resnet_prefix}.norm1.conv_y.bias"], + f"{diffusers_resnet_prefix}.norm1.conv_b.weight": checkpoint[f"{resnet_prefix}.norm1.conv_b.weight"], + f"{diffusers_resnet_prefix}.norm1.conv_b.bias": checkpoint[f"{resnet_prefix}.norm1.conv_b.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm2.norm_layer.weight"], + f"{diffusers_resnet_prefix}.norm2.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm2.norm_layer.bias"], + f"{diffusers_resnet_prefix}.norm2.conv_y.weight": checkpoint[f"{resnet_prefix}.norm2.conv_y.weight"], + f"{diffusers_resnet_prefix}.norm2.conv_y.bias": checkpoint[f"{resnet_prefix}.norm2.conv_y.bias"], + f"{diffusers_resnet_prefix}.norm2.conv_b.weight": checkpoint[f"{resnet_prefix}.norm2.conv_b.weight"], + f"{diffusers_resnet_prefix}.norm2.conv_b.bias": checkpoint[f"{resnet_prefix}.norm2.conv_b.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + + +def movq_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # norm + f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + # query + f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + + +def movq_attention_to_diffusers_checkpoint_spatial_norm(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # norm + f"{diffusers_attention_prefix}.spatial_norm.norm_layer.weight": checkpoint[ + f"{attention_prefix}.norm.norm_layer.weight" + ], + f"{diffusers_attention_prefix}.spatial_norm.norm_layer.bias": checkpoint[ + f"{attention_prefix}.norm.norm_layer.bias" + ], + f"{diffusers_attention_prefix}.spatial_norm.conv_y.weight": checkpoint[ + f"{attention_prefix}.norm.conv_y.weight" + ], + f"{diffusers_attention_prefix}.spatial_norm.conv_y.bias": checkpoint[f"{attention_prefix}.norm.conv_y.bias"], + f"{diffusers_attention_prefix}.spatial_norm.conv_b.weight": checkpoint[ + f"{attention_prefix}.norm.conv_b.weight" + ], + f"{diffusers_attention_prefix}.spatial_norm.conv_b.bias": checkpoint[f"{attention_prefix}.norm.conv_b.bias"], + # query + f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + + +def movq_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + diffusers_checkpoint.update(movq_encoder_to_diffusers_checkpoint(model, checkpoint)) + + # quant_conv + + diffusers_checkpoint.update( + { + "quant_conv.weight": checkpoint["quant_conv.weight"], + "quant_conv.bias": checkpoint["quant_conv.bias"], + } + ) + + # quantize + diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding.weight"]}) + + # post_quant_conv + diffusers_checkpoint.update( + { + "post_quant_conv.weight": checkpoint["post_quant_conv.weight"], + "post_quant_conv.bias": checkpoint["post_quant_conv.bias"], + } + ) + + # decoder + diffusers_checkpoint.update(movq_decoder_to_diffusers_checkpoint(model, checkpoint)) + + return diffusers_checkpoint + + +def movq(*, args, checkpoint_map_location): + print("loading movq") + + movq_checkpoint = torch.load(args.movq_checkpoint_path, map_location=checkpoint_map_location) + + movq_model = movq_model_from_original_config() + + movq_diffusers_checkpoint = movq_original_checkpoint_to_diffusers_checkpoint(movq_model, movq_checkpoint) + + del movq_checkpoint + + load_checkpoint_to_model(movq_diffusers_checkpoint, movq_model, strict=True) + + print("done loading movq") + + return movq_model + + +def load_checkpoint_to_model(checkpoint, model, strict=False): + with tempfile.NamedTemporaryFile(delete=False) as file: + torch.save(checkpoint, file.name) + del checkpoint + if strict: + model.load_state_dict(torch.load(file.name), strict=True) + else: + load_checkpoint_and_dispatch(model, file.name, device_map="auto") + os.remove(file.name) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--prior_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the prior checkpoint to convert.", + ) + parser.add_argument( + "--clip_stat_path", + default=None, + type=str, + required=False, + help="Path to the clip stats checkpoint to convert.", + ) + parser.add_argument( + "--text2img_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the text2img checkpoint to convert.", + ) + parser.add_argument( + "--movq_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the text2img checkpoint to convert.", + ) + parser.add_argument( + "--inpaint_text2img_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the inpaint text2img checkpoint to convert.", + ) + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + parser.add_argument( + "--debug", + default=None, + type=str, + required=False, + help="Only run a specific stage of the convert script. Used for debugging", + ) + + args = parser.parse_args() + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + if args.debug is not None: + print(f"debug: only executing {args.debug}") + + if args.debug is None: + print("to-do") + elif args.debug == "prior": + prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location) + prior_model.save_pretrained(args.dump_path) + elif args.debug == "text2img": + unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location) + unet_model.save_pretrained(f"{args.dump_path}/unet") + elif args.debug == "inpaint_text2img": + inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location) + inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet") + elif args.debug == "decoder": + decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location) + decoder.save_pretrained(f"{args.dump_path}/decoder") + else: + raise ValueError(f"unknown debug value : {args.debug}") diff --git a/diffusers/scripts/convert_ldm_original_checkpoint_to_diffusers.py b/diffusers/scripts/convert_ldm_original_checkpoint_to_diffusers.py new file mode 100755 index 0000000..ada7dc6 --- /dev/null +++ b/diffusers/scripts/convert_ldm_original_checkpoint_to_diffusers.py @@ -0,0 +1,359 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the LDM checkpoints.""" + +import argparse +import json + +import torch + +from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def convert_ldm_checkpoint(checkpoint, config): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"] + new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"] + new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in checkpoint if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["num_res_blocks"] + 1) + layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1) + + resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in checkpoint: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[ + f"input_blocks.{i}.0.op.weight" + ] + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[ + f"input_blocks.{i}.0.op.bias" + ] + continue + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"} + assign_to_checkpoint( + paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"input_blocks.{i}.1", + "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", + } + to_split = { + f"input_blocks.{i}.1.qkv.bias": { + "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias", + "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias", + "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias", + }, + f"input_blocks.{i}.1.qkv.weight": { + "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight", + "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight", + "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight", + }, + } + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[meta_path], + attention_paths_to_split=to_split, + config=config, + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config) + + attentions_paths = renew_attention_paths(attentions) + to_split = { + "middle_block.1.qkv.bias": { + "key": "mid_block.attentions.0.key.bias", + "query": "mid_block.attentions.0.query.bias", + "value": "mid_block.attentions.0.value.bias", + }, + "middle_block.1.qkv.weight": { + "key": "mid_block.attentions.0.key.weight", + "query": "mid_block.attentions.0.query.weight", + "value": "mid_block.attentions.0.value.weight", + }, + } + assign_to_checkpoint( + attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["num_res_blocks"] + 1) + layer_in_block_id = i % (config["num_res_blocks"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config) + + if ["conv.weight", "conv.bias"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + to_split = { + f"output_blocks.{i}.1.qkv.bias": { + "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias", + "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias", + "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias", + }, + f"output_blocks.{i}.1.qkv.weight": { + "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight", + "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight", + "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight", + }, + } + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[meta_path], + attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None, + config=config, + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = checkpoint[old_path] + + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the architecture.", + ) + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + args = parser.parse_args() + + checkpoint = torch.load(args.checkpoint_path) + + with open(args.config_file) as f: + config = json.loads(f.read()) + + converted_checkpoint = convert_ldm_checkpoint(checkpoint, config) + + if "ldm" in config: + del config["ldm"] + + model = UNet2DModel(**config) + model.load_state_dict(converted_checkpoint) + + try: + scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1])) + vqvae = VQModel.from_pretrained("/".join(args.checkpoint_path.split("/")[:-1])) + + pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae) + pipe.save_pretrained(args.dump_path) + except: # noqa: E722 + model.save_pretrained(args.dump_path) diff --git a/diffusers/scripts/convert_lora_safetensor_to_diffusers.py b/diffusers/scripts/convert_lora_safetensor_to_diffusers.py new file mode 100755 index 0000000..4237452 --- /dev/null +++ b/diffusers/scripts/convert_lora_safetensor_to_diffusers.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2024, Haofan Wang, Qixun Wang, All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Conversion script for the LoRA's safetensors checkpoints.""" + +import argparse + +import torch +from safetensors.torch import load_file + +from diffusers import StableDiffusionPipeline + + +def convert(base_model_path, checkpoint_path, LORA_PREFIX_UNET, LORA_PREFIX_TEXT_ENCODER, alpha): + # load base model + pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) + + # load LoRA weight from .safetensors + state_dict = load_file(checkpoint_path) + + visited = [] + + # directly update weight in diffusers model + for key in state_dict: + # it is suggested to print out the key, it usually will be something like below + # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + + # as we have set the alpha beforehand, so just skip + if ".alpha" in key or key in visited: + continue + + if "text" in key: + layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = pipeline.unet + + # find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + pair_keys = [] + if "lora_down" in key: + pair_keys.append(key.replace("lora_down", "lora_up")) + pair_keys.append(key) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora_up", "lora_down")) + + # update weight + if len(state_dict[pair_keys[0]].shape) == 4: + weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) + weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down) + + # update visited list + for item in pair_keys: + visited.append(item) + + return pipeline + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument( + "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" + ) + parser.add_argument( + "--lora_prefix_text_encoder", + default="lora_te", + type=str, + help="The prefix of text encoder weight in safetensors", + ) + parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") + parser.add_argument( + "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." + ) + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + args = parser.parse_args() + + base_model_path = args.base_model_path + checkpoint_path = args.checkpoint_path + dump_path = args.dump_path + lora_prefix_unet = args.lora_prefix_unet + lora_prefix_text_encoder = args.lora_prefix_text_encoder + alpha = args.alpha + + pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) + + pipe = pipe.to(args.device) + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/diffusers/scripts/convert_lumina_to_diffusers.py b/diffusers/scripts/convert_lumina_to_diffusers.py new file mode 100755 index 0000000..a12625d --- /dev/null +++ b/diffusers/scripts/convert_lumina_to_diffusers.py @@ -0,0 +1,142 @@ +import argparse +import os + +import torch +from safetensors.torch import load_file +from transformers import AutoModel, AutoTokenizer + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline + + +def main(args): + # checkpoint from https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT or https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I + all_sd = load_file(args.origin_ckpt_path, device="cpu") + converted_state_dict = {} + # pad token + converted_state_dict["pad_token"] = all_sd["pad_token"] + + # patch embed + converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"] + converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"] + + # time and caption embed + converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"] + converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"] + converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"] + converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"] + converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"] + converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"] + converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"] + converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"] + + for i in range(24): + # adaln + converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"] + converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"] + converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"] + + # qkv + converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"] + converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"] + converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"] + + # cap + converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"] + converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"] + converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"] + + # output + converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"] + + # attention + # qk norm + converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"] + converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"] + + converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"] + converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"] + + converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"] + converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"] + + converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"] + converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"] + + # attention norm + converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"] + converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"] + converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"] + + # feed forward + converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"] + converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"] + converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"] + + # feed forward norm + converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"] + converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"] + + # final layer + converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"] + converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"] + + converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"] + converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"] + + # Lumina-Next-SFT 2B + transformer = LuminaNextDiT2DModel( + sample_size=128, + patch_size=2, + in_channels=4, + hidden_size=2304, + num_layers=24, + num_attention_heads=32, + num_kv_heads=8, + multiple_of=256, + ffn_dim_multiplier=None, + norm_eps=1e-5, + learn_sigma=True, + qk_norm=True, + cross_attention_dim=2048, + scaling_factor=1.0, + ) + transformer.load_state_dict(converted_state_dict, strict=True) + + num_model_params = sum(p.numel() for p in transformer.parameters()) + print(f"Total number of transformer parameters: {num_model_params}") + + if args.only_transformer: + transformer.save_pretrained(os.path.join(args.dump_path, "transformer")) + else: + scheduler = FlowMatchEulerDiscreteScheduler() + + vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32) + + tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") + text_encoder = AutoModel.from_pretrained("google/gemma-2b") + + pipeline = LuminaText2ImgPipeline( + tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler + ) + pipeline.save_pretrained(args.dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--image_size", + default=1024, + type=int, + choices=[256, 512, 1024], + required=False, + help="Image size of pretrained model, either 512 or 1024.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") + parser.add_argument("--only_transformer", default=True, type=bool, required=True) + + args = parser.parse_args() + main(args) diff --git a/diffusers/scripts/convert_models_diffuser_to_diffusers.py b/diffusers/scripts/convert_models_diffuser_to_diffusers.py new file mode 100755 index 0000000..cc5321e --- /dev/null +++ b/diffusers/scripts/convert_models_diffuser_to_diffusers.py @@ -0,0 +1,100 @@ +import json +import os + +import torch + +from diffusers import UNet1DModel + + +os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True) +os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True) + +os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True) + + +def unet(hor): + if hor == 128: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D") + + elif hor == 32: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 64, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D") + model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch") + state_dict = model.state_dict() + config = { + "down_block_types": down_block_types, + "block_out_channels": block_out_channels, + "up_block_types": up_block_types, + "layers_per_block": 1, + "use_timestep_embedding": True, + "out_block_type": "OutConv1DBlock", + "norm_num_groups": 8, + "downsample_each_block": False, + "in_channels": 14, + "out_channels": 14, + "extra_in_channels": 0, + "time_embedding_type": "positional", + "flip_sin_to_cos": False, + "freq_shift": 1, + "sample_size": 65536, + "mid_block_type": "MidResTemporalBlock1D", + "act_fn": "mish", + } + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + mapping = dict(zip(model.state_dict().keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin") + with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f: + json.dump(config, f) + + +def value_function(): + config = { + "in_channels": 14, + "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + "up_block_types": (), + "out_block_type": "ValueFunction", + "mid_block_type": "ValueFunctionMidBlock1D", + "block_out_channels": (32, 64, 128, 256), + "layers_per_block": 1, + "downsample_each_block": True, + "sample_size": 65536, + "out_channels": 14, + "extra_in_channels": 0, + "time_embedding_type": "positional", + "use_timestep_embedding": True, + "flip_sin_to_cos": False, + "freq_shift": 1, + "norm_num_groups": 8, + "act_fn": "mish", + } + + model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch") + state_dict = model + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + + mapping = dict(zip(state_dict.keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin") + with open("hub/hopper-medium-v2/value_function/config.json", "w") as f: + json.dump(config, f) + + +if __name__ == "__main__": + unet(32) + # unet(128) + value_function() diff --git a/diffusers/scripts/convert_ms_text_to_video_to_diffusers.py b/diffusers/scripts/convert_ms_text_to_video_to_diffusers.py new file mode 100755 index 0000000..0251ab6 --- /dev/null +++ b/diffusers/scripts/convert_ms_text_to_video_to_diffusers.py @@ -0,0 +1,428 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the LDM checkpoints.""" + +import argparse + +import torch + +from diffusers import UNet3DConditionModel + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + weight = old_checkpoint[path["old"]] + names = ["proj_attn.weight"] + names_2 = ["proj_out.weight", "proj_in.weight"] + if any(k in new_path for k in names): + checkpoint[new_path] = weight[:, :, 0] + elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path: + checkpoint[new_path] = weight[:, :, 0] + else: + checkpoint[new_path] = weight + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + mapping.append({"old": old_item, "new": old_item}) + + return mapping + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + if "temopral_conv" not in old_item: + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")] + paths = renew_attention_paths(first_temp_attention) + meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key] + + if f"input_blocks.{i}.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temporal_convs = [key for key in resnets if "temopral_conv" in key] + paths = renew_temp_conv_paths(temporal_convs) + meta_path = { + "old": f"input_blocks.{i}.0.temopral_conv", + "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(temp_attentions): + paths = renew_attention_paths(temp_attentions) + meta_path = { + "old": f"input_blocks.{i}.2", + "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key] + attentions = middle_blocks[1] + temp_attentions = middle_blocks[2] + resnet_1 = middle_blocks[3] + temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key] + + resnet_0_paths = renew_resnet_paths(resnet_0) + meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"} + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0) + meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"} + assign_to_checkpoint( + temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + resnet_1_paths = renew_resnet_paths(resnet_1) + meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"} + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1) + meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"} + assign_to_checkpoint( + temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path] + ) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temp_attentions_paths = renew_attention_paths(temp_attentions) + meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"} + assign_to_checkpoint( + temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + temporal_convs = [key for key in resnets if "temopral_conv" in key] + paths = renew_temp_conv_paths(temporal_convs) + meta_path = { + "old": f"output_blocks.{i}.0.temopral_conv", + "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(temp_attentions): + paths = renew_attention_paths(temp_attentions) + meta_path = { + "old": f"output_blocks.{i}.2", + "new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + new_checkpoint[new_path] = unet_state_dict[old_path] + + temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l] + for path in temopral_conv_paths: + pruned_path = path.split("temopral_conv.")[-1] + old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path]) + new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path]) + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + args = parser.parse_args() + + unet_checkpoint = torch.load(args.checkpoint_path, map_location="cpu") + unet = UNet3DConditionModel() + + converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config) + + diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys()) + diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys()) + + assert len(diff_0) == len(diff_1) == 0, "Converted weights don't match" + + # load state_dict + unet.load_state_dict(converted_ckpt) + + unet.save_pretrained(args.dump_path) + + # -- finish converting the unet -- diff --git a/diffusers/scripts/convert_music_spectrogram_to_diffusers.py b/diffusers/scripts/convert_music_spectrogram_to_diffusers.py new file mode 100755 index 0000000..1ac9ced --- /dev/null +++ b/diffusers/scripts/convert_music_spectrogram_to_diffusers.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +import argparse +import os + +import jax as jnp +import numpy as onp +import torch +import torch.nn as nn +from music_spectrogram_diffusion import inference +from t5x import checkpoints + +from diffusers import DDPMScheduler, OnnxRuntimeModel, SpectrogramDiffusionPipeline +from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder + + +MODEL = "base_with_context" + + +def load_notes_encoder(weights, model): + model.token_embedder.weight = nn.Parameter(torch.Tensor(weights["token_embedder"]["embedding"])) + model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False) + for lyr_num, lyr in enumerate(model.encoders): + ly_weight = weights[f"layers_{lyr_num}"] + lyr.layer[0].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_attention_layer_norm"]["scale"])) + + attention_weights = ly_weight["attention"] + lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T)) + lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T)) + lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T)) + lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T)) + + lyr.layer[1].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"])) + + lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T)) + lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T)) + lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T)) + + model.layer_norm.weight = nn.Parameter(torch.Tensor(weights["encoder_norm"]["scale"])) + return model + + +def load_continuous_encoder(weights, model): + model.input_proj.weight = nn.Parameter(torch.Tensor(weights["input_proj"]["kernel"].T)) + + model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False) + + for lyr_num, lyr in enumerate(model.encoders): + ly_weight = weights[f"layers_{lyr_num}"] + attention_weights = ly_weight["attention"] + + lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T)) + lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T)) + lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T)) + lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T)) + lyr.layer[0].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_attention_layer_norm"]["scale"])) + + lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T)) + lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T)) + lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T)) + lyr.layer[1].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"])) + + model.layer_norm.weight = nn.Parameter(torch.Tensor(weights["encoder_norm"]["scale"])) + + return model + + +def load_decoder(weights, model): + model.conditioning_emb[0].weight = nn.Parameter(torch.Tensor(weights["time_emb_dense0"]["kernel"].T)) + model.conditioning_emb[2].weight = nn.Parameter(torch.Tensor(weights["time_emb_dense1"]["kernel"].T)) + + model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False) + + model.continuous_inputs_projection.weight = nn.Parameter( + torch.Tensor(weights["continuous_inputs_projection"]["kernel"].T) + ) + + for lyr_num, lyr in enumerate(model.decoders): + ly_weight = weights[f"layers_{lyr_num}"] + lyr.layer[0].layer_norm.weight = nn.Parameter( + torch.Tensor(ly_weight["pre_self_attention_layer_norm"]["scale"]) + ) + + lyr.layer[0].FiLMLayer.scale_bias.weight = nn.Parameter( + torch.Tensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T) + ) + + attention_weights = ly_weight["self_attention"] + lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T)) + lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T)) + lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T)) + lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T)) + + attention_weights = ly_weight["MultiHeadDotProductAttention_0"] + lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T)) + lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T)) + lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T)) + lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T)) + lyr.layer[1].layer_norm.weight = nn.Parameter( + torch.Tensor(ly_weight["pre_cross_attention_layer_norm"]["scale"]) + ) + + lyr.layer[2].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"])) + lyr.layer[2].film.scale_bias.weight = nn.Parameter( + torch.Tensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T) + ) + lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T)) + lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T)) + lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T)) + + model.decoder_norm.weight = nn.Parameter(torch.Tensor(weights["decoder_norm"]["scale"])) + + model.spec_out.weight = nn.Parameter(torch.Tensor(weights["spec_out_dense"]["kernel"].T)) + + return model + + +def main(args): + t5_checkpoint = checkpoints.load_t5x_checkpoint(args.checkpoint_path) + t5_checkpoint = jnp.tree_util.tree_map(onp.array, t5_checkpoint) + + gin_overrides = [ + "from __gin__ import dynamic_registration", + "from music_spectrogram_diffusion.models.diffusion import diffusion_utils", + "diffusion_utils.ClassifierFreeGuidanceConfig.eval_condition_weight = 2.0", + "diffusion_utils.DiffusionConfig.classifier_free_guidance = @diffusion_utils.ClassifierFreeGuidanceConfig()", + ] + + gin_file = os.path.join(args.checkpoint_path, "..", "config.gin") + gin_config = inference.parse_training_gin_file(gin_file, gin_overrides) + synth_model = inference.InferenceModel(args.checkpoint_path, gin_config) + + scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", variance_type="fixed_large") + + notes_encoder = SpectrogramNotesEncoder( + max_length=synth_model.sequence_length["inputs"], + vocab_size=synth_model.model.module.config.vocab_size, + d_model=synth_model.model.module.config.emb_dim, + dropout_rate=synth_model.model.module.config.dropout_rate, + num_layers=synth_model.model.module.config.num_encoder_layers, + num_heads=synth_model.model.module.config.num_heads, + d_kv=synth_model.model.module.config.head_dim, + d_ff=synth_model.model.module.config.mlp_dim, + feed_forward_proj="gated-gelu", + ) + + continuous_encoder = SpectrogramContEncoder( + input_dims=synth_model.audio_codec.n_dims, + targets_context_length=synth_model.sequence_length["targets_context"], + d_model=synth_model.model.module.config.emb_dim, + dropout_rate=synth_model.model.module.config.dropout_rate, + num_layers=synth_model.model.module.config.num_encoder_layers, + num_heads=synth_model.model.module.config.num_heads, + d_kv=synth_model.model.module.config.head_dim, + d_ff=synth_model.model.module.config.mlp_dim, + feed_forward_proj="gated-gelu", + ) + + decoder = T5FilmDecoder( + input_dims=synth_model.audio_codec.n_dims, + targets_length=synth_model.sequence_length["targets_context"], + max_decoder_noise_time=synth_model.model.module.config.max_decoder_noise_time, + d_model=synth_model.model.module.config.emb_dim, + num_layers=synth_model.model.module.config.num_decoder_layers, + num_heads=synth_model.model.module.config.num_heads, + d_kv=synth_model.model.module.config.head_dim, + d_ff=synth_model.model.module.config.mlp_dim, + dropout_rate=synth_model.model.module.config.dropout_rate, + ) + + notes_encoder = load_notes_encoder(t5_checkpoint["target"]["token_encoder"], notes_encoder) + continuous_encoder = load_continuous_encoder(t5_checkpoint["target"]["continuous_encoder"], continuous_encoder) + decoder = load_decoder(t5_checkpoint["target"]["decoder"], decoder) + + melgan = OnnxRuntimeModel.from_pretrained("kashif/soundstream_mel_decoder") + + pipe = SpectrogramDiffusionPipeline( + notes_encoder=notes_encoder, + continuous_encoder=continuous_encoder, + decoder=decoder, + scheduler=scheduler, + melgan=melgan, + ) + if args.save: + pipe.save_pretrained(args.output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--output_path", default=None, type=str, required=True, help="Path to the converted model.") + parser.add_argument( + "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not." + ) + parser.add_argument( + "--checkpoint_path", + default=f"{MODEL}/checkpoint_500000", + type=str, + required=False, + help="Path to the original jax model checkpoint.", + ) + args = parser.parse_args() + + main(args) diff --git a/diffusers/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py b/diffusers/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py new file mode 100755 index 0000000..2d67123 --- /dev/null +++ b/diffusers/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the NCSNPP checkpoints.""" + +import argparse +import json + +import torch + +from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel + + +def convert_ncsnpp_checkpoint(checkpoint, config): + """ + Takes a state dict and the path to + """ + new_model_architecture = UNet2DModel(**config) + new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data + new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data + new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data + new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data + + new_model_architecture.time_embedding.linear_2.weight.data = checkpoint["all_modules.2.weight"].data + new_model_architecture.time_embedding.linear_2.bias.data = checkpoint["all_modules.2.bias"].data + + new_model_architecture.conv_in.weight.data = checkpoint["all_modules.3.weight"].data + new_model_architecture.conv_in.bias.data = checkpoint["all_modules.3.bias"].data + + new_model_architecture.conv_norm_out.weight.data = checkpoint[list(checkpoint.keys())[-4]].data + new_model_architecture.conv_norm_out.bias.data = checkpoint[list(checkpoint.keys())[-3]].data + new_model_architecture.conv_out.weight.data = checkpoint[list(checkpoint.keys())[-2]].data + new_model_architecture.conv_out.bias.data = checkpoint[list(checkpoint.keys())[-1]].data + + module_index = 4 + + def set_attention_weights(new_layer, old_checkpoint, index): + new_layer.query.weight.data = old_checkpoint[f"all_modules.{index}.NIN_0.W"].data.T + new_layer.key.weight.data = old_checkpoint[f"all_modules.{index}.NIN_1.W"].data.T + new_layer.value.weight.data = old_checkpoint[f"all_modules.{index}.NIN_2.W"].data.T + + new_layer.query.bias.data = old_checkpoint[f"all_modules.{index}.NIN_0.b"].data + new_layer.key.bias.data = old_checkpoint[f"all_modules.{index}.NIN_1.b"].data + new_layer.value.bias.data = old_checkpoint[f"all_modules.{index}.NIN_2.b"].data + + new_layer.proj_attn.weight.data = old_checkpoint[f"all_modules.{index}.NIN_3.W"].data.T + new_layer.proj_attn.bias.data = old_checkpoint[f"all_modules.{index}.NIN_3.b"].data + + new_layer.group_norm.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data + new_layer.group_norm.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data + + def set_resnet_weights(new_layer, old_checkpoint, index): + new_layer.conv1.weight.data = old_checkpoint[f"all_modules.{index}.Conv_0.weight"].data + new_layer.conv1.bias.data = old_checkpoint[f"all_modules.{index}.Conv_0.bias"].data + new_layer.norm1.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data + new_layer.norm1.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data + + new_layer.conv2.weight.data = old_checkpoint[f"all_modules.{index}.Conv_1.weight"].data + new_layer.conv2.bias.data = old_checkpoint[f"all_modules.{index}.Conv_1.bias"].data + new_layer.norm2.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.weight"].data + new_layer.norm2.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.bias"].data + + new_layer.time_emb_proj.weight.data = old_checkpoint[f"all_modules.{index}.Dense_0.weight"].data + new_layer.time_emb_proj.bias.data = old_checkpoint[f"all_modules.{index}.Dense_0.bias"].data + + if new_layer.in_channels != new_layer.out_channels or new_layer.up or new_layer.down: + new_layer.conv_shortcut.weight.data = old_checkpoint[f"all_modules.{index}.Conv_2.weight"].data + new_layer.conv_shortcut.bias.data = old_checkpoint[f"all_modules.{index}.Conv_2.bias"].data + + for i, block in enumerate(new_model_architecture.downsample_blocks): + has_attentions = hasattr(block, "attentions") + for j in range(len(block.resnets)): + set_resnet_weights(block.resnets[j], checkpoint, module_index) + module_index += 1 + if has_attentions: + set_attention_weights(block.attentions[j], checkpoint, module_index) + module_index += 1 + + if hasattr(block, "downsamplers") and block.downsamplers is not None: + set_resnet_weights(block.resnet_down, checkpoint, module_index) + module_index += 1 + block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.Conv_0.weight"].data + block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data + module_index += 1 + + set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index) + module_index += 1 + set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index) + module_index += 1 + set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index) + module_index += 1 + + for i, block in enumerate(new_model_architecture.up_blocks): + has_attentions = hasattr(block, "attentions") + for j in range(len(block.resnets)): + set_resnet_weights(block.resnets[j], checkpoint, module_index) + module_index += 1 + if has_attentions: + set_attention_weights( + block.attentions[0], checkpoint, module_index + ) # why can there only be a single attention layer for up? + module_index += 1 + + if hasattr(block, "resnet_up") and block.resnet_up is not None: + block.skip_norm.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data + block.skip_norm.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data + module_index += 1 + block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data + block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data + module_index += 1 + set_resnet_weights(block.resnet_up, checkpoint, module_index) + module_index += 1 + + new_model_architecture.conv_norm_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data + new_model_architecture.conv_norm_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data + module_index += 1 + new_model_architecture.conv_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data + new_model_architecture.conv_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data + + return new_model_architecture.state_dict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", + default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin", + type=str, + required=False, + help="Path to the checkpoint to convert.", + ) + + parser.add_argument( + "--config_file", + default="/Users/arthurzucker/Work/diffusers/ArthurZ/config.json", + type=str, + required=False, + help="The config json file corresponding to the architecture.", + ) + + parser.add_argument( + "--dump_path", + default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model_new.pt", + type=str, + required=False, + help="Path to the output model.", + ) + + args = parser.parse_args() + + checkpoint = torch.load(args.checkpoint_path, map_location="cpu") + + with open(args.config_file) as f: + config = json.loads(f.read()) + + converted_checkpoint = convert_ncsnpp_checkpoint( + checkpoint, + config, + ) + + if "sde" in config: + del config["sde"] + + model = UNet2DModel(**config) + model.load_state_dict(converted_checkpoint) + + try: + scheduler = ScoreSdeVeScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1])) + + pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler) + pipe.save_pretrained(args.dump_path) + except: # noqa: E722 + model.save_pretrained(args.dump_path) diff --git a/diffusers/scripts/convert_original_audioldm2_to_diffusers.py b/diffusers/scripts/convert_original_audioldm2_to_diffusers.py new file mode 100755 index 0000000..ea9c02d --- /dev/null +++ b/diffusers/scripts/convert_original_audioldm2_to_diffusers.py @@ -0,0 +1,1135 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the AudioLDM2 checkpoints.""" + +import argparse +import re +from typing import List, Union + +import torch +import yaml +from transformers import ( + AutoFeatureExtractor, + AutoTokenizer, + ClapConfig, + ClapModel, + GPT2Config, + GPT2Model, + SpeechT5HifiGan, + SpeechT5HifiGanConfig, + T5Config, + T5EncoderModel, +) + +from diffusers import ( + AudioLDM2Pipeline, + AudioLDM2ProjectionModel, + AudioLDM2UNet2DConditionModel, + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import is_safetensors_available +from diffusers.utils.import_utils import BACKENDS_MAPPING + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths +def renew_attention_paths(old_list): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"] + proj_key = "to_out.0.weight" + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys or ".".join(key.split(".")[-3:]) == proj_key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key].squeeze() + + +def create_unet_diffusers_config(original_config, image_size: int): + """ + Creates a UNet config for diffusers based on the config of the original AudioLDM2 model. + """ + unet_params = original_config["model"]["params"]["unet_config"]["params"] + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + cross_attention_dim = list(unet_params["context_dim"]) if "context_dim" in unet_params else block_out_channels + if len(cross_attention_dim) > 1: + # require two or more cross-attention layers per-block, each of different dimension + cross_attention_dim = [cross_attention_dim for _ in range(len(block_out_channels))] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "out_channels": unet_params["out_channels"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "transformer_layers_per_block": unet_params["transformer_depth"], + "cross_attention_dim": tuple(cross_attention_dim), + } + + return config + + +# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config +def create_vae_diffusers_config(original_config, checkpoint, image_size: int): + """ + Creates a VAE config for diffusers based on the config of the original AudioLDM2 model. Compared to the original + Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE. + """ + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215 + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + "scaling_factor": float(scaling_factor), + } + return config + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config["model"]["params"]["timesteps"], + beta_start=original_config["model"]["params"]["linear_start"], + beta_end=original_config["model"]["params"]["linear_end"], + beta_schedule="scaled_linear", + ) + return schedular + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted UNet checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + # strip the unet prefix from the weight names + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] + for layer_id in range(num_output_blocks) + } + + # Check how many Transformer blocks we have per layer + if isinstance(config.get("cross_attention_dim"), (list, tuple)): + if isinstance(config["cross_attention_dim"][0], (list, tuple)): + # in this case we have multiple cross-attention layers per-block + num_attention_layers = len(config.get("cross_attention_dim")[0]) + else: + num_attention_layers = 1 + + if config.get("extra_self_attn_layer"): + num_attention_layers += 1 + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.0" not in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = [ + { + "old": f"input_blocks.{i}.{1 + layer_id}", + "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id * num_attention_layers + layer_id}", + } + for layer_id in range(num_attention_layers) + ] + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=meta_path, config=config + ) + + resnet_0 = middle_blocks[0] + resnet_1 = middle_blocks[num_middle_blocks - 1] + + resnet_0_paths = renew_resnet_paths(resnet_0) + meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"} + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_1_paths = renew_resnet_paths(resnet_1) + meta_path = {"old": f"middle_block.{len(middle_blocks) - 1}", "new": "mid_block.resnets.1"} + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(1, num_middle_blocks - 1): + attentions = middle_blocks[i] + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": f"middle_block.{i}", "new": f"mid_block.attentions.{i - 1}"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.0" not in key] + + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + attentions.remove(f"output_blocks.{i}.{index}.conv.bias") + attentions.remove(f"output_blocks.{i}.{index}.conv.weight") + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = [ + { + "old": f"output_blocks.{i}.{1 + layer_id}", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id * num_attention_layers + layer_id}", + } + for layer_id in range(num_attention_layers) + ] + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=meta_path, config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +CLAP_KEYS_TO_MODIFY_MAPPING = { + "text_branch": "text_model", + "audio_branch": "audio_model.audio_encoder", + "attn": "attention.self", + "self.proj": "output.dense", + "attention.self_mask": "attn_mask", + "mlp.fc1": "intermediate.dense", + "mlp.fc2": "output.dense", + "norm1": "layernorm_before", + "norm2": "layernorm_after", + "bn0": "batch_norm", +} + +CLAP_KEYS_TO_IGNORE = [ + "text_transform", + "audio_transform", + "stft", + "logmel_extractor", + "tscam_conv", + "head", + "attn_mask", +] + +CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"] + + +def convert_open_clap_checkpoint(checkpoint): + """ + Takes a state dict and returns a converted CLAP checkpoint. + """ + # extract state dict for CLAP text embedding model, discarding the audio component + model_state_dict = {} + model_key = "clap.model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(model_key): + model_state_dict[key.replace(model_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + sequential_layers_pattern = r".*sequential.(\d+).*" + text_projection_pattern = r".*_projection.(\d+).*" + + for key, value in model_state_dict.items(): + # check if key should be ignored in mapping - if so map it to a key name that we'll filter out at the end + for key_to_ignore in CLAP_KEYS_TO_IGNORE: + if key_to_ignore in key: + key = "spectrogram" + + # check if any key needs to be modified + for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + if re.match(sequential_layers_pattern, key): + # replace sequential layers with list + sequential_layer = re.match(sequential_layers_pattern, key).group(1) + + key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") + elif re.match(text_projection_pattern, key): + projecton_layer = int(re.match(text_projection_pattern, key).group(1)) + + # Because in CLAP they use `nn.Sequential`... + transformers_projection_layer = 1 if projecton_layer == 0 else 2 + + key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.") + + if "audio" and "qkv" in key: + # split qkv into query key and value + mixed_qkv = value + qkv_dim = mixed_qkv.size(0) // 3 + + query_layer = mixed_qkv[:qkv_dim] + key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[qkv_dim * 2 :] + + new_checkpoint[key.replace("qkv", "query")] = query_layer + new_checkpoint[key.replace("qkv", "key")] = key_layer + new_checkpoint[key.replace("qkv", "value")] = value_layer + elif key != "spectrogram": + new_checkpoint[key] = value + + return new_checkpoint + + +def create_transformers_vocoder_config(original_config): + """ + Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model. + """ + vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"] + + config = { + "model_in_dim": vocoder_params["num_mels"], + "sampling_rate": vocoder_params["sampling_rate"], + "upsample_initial_channel": vocoder_params["upsample_initial_channel"], + "upsample_rates": list(vocoder_params["upsample_rates"]), + "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]), + "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]), + "resblock_dilation_sizes": [ + list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"] + ], + "normalize_before": False, + } + + return config + + +def extract_sub_model(checkpoint, key_prefix): + """ + Takes a state dict and returns the state dict for a particular sub-model. + """ + + sub_model_state_dict = {} + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(key_prefix): + sub_model_state_dict[key.replace(key_prefix, "")] = checkpoint.get(key) + + return sub_model_state_dict + + +def convert_hifigan_checkpoint(checkpoint, config): + """ + Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint. + """ + # extract state dict for vocoder + vocoder_state_dict = extract_sub_model(checkpoint, key_prefix="first_stage_model.vocoder.") + + # fix upsampler keys, everything else is correct already + for i in range(len(config.upsample_rates)): + vocoder_state_dict[f"upsampler.{i}.weight"] = vocoder_state_dict.pop(f"ups.{i}.weight") + vocoder_state_dict[f"upsampler.{i}.bias"] = vocoder_state_dict.pop(f"ups.{i}.bias") + + if not config.normalize_before: + # if we don't set normalize_before then these variables are unused, so we set them to their initialised values + vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim) + vocoder_state_dict["scale"] = torch.ones(config.model_in_dim) + + return vocoder_state_dict + + +def convert_projection_checkpoint(checkpoint): + projection_state_dict = {} + conditioner_state_dict = extract_sub_model(checkpoint, key_prefix="cond_stage_models.0.") + + projection_state_dict["sos_embed"] = conditioner_state_dict["start_of_sequence_tokens.weight"][0] + projection_state_dict["sos_embed_1"] = conditioner_state_dict["start_of_sequence_tokens.weight"][1] + + projection_state_dict["eos_embed"] = conditioner_state_dict["end_of_sequence_tokens.weight"][0] + projection_state_dict["eos_embed_1"] = conditioner_state_dict["end_of_sequence_tokens.weight"][1] + + projection_state_dict["projection.weight"] = conditioner_state_dict["input_sequence_embed_linear.0.weight"] + projection_state_dict["projection.bias"] = conditioner_state_dict["input_sequence_embed_linear.0.bias"] + + projection_state_dict["projection_1.weight"] = conditioner_state_dict["input_sequence_embed_linear.1.weight"] + projection_state_dict["projection_1.bias"] = conditioner_state_dict["input_sequence_embed_linear.1.bias"] + + return projection_state_dict + + +# Adapted from https://github.com/haoheliu/AudioLDM2/blob/81ad2c6ce015c1310387695e2dae975a7d2ed6fd/audioldm2/utils.py#L143 +DEFAULT_CONFIG = { + "model": { + "params": { + "linear_start": 0.0015, + "linear_end": 0.0195, + "timesteps": 1000, + "channels": 8, + "scale_by_std": True, + "unet_config": { + "target": "audioldm2.latent_diffusion.openaimodel.UNetModel", + "params": { + "context_dim": [None, 768, 1024], + "in_channels": 8, + "out_channels": 8, + "model_channels": 128, + "attention_resolutions": [8, 4, 2], + "num_res_blocks": 2, + "channel_mult": [1, 2, 3, 5], + "num_head_channels": 32, + "transformer_depth": 1, + }, + }, + "first_stage_config": { + "target": "audioldm2.variational_autoencoder.autoencoder.AutoencoderKL", + "params": { + "embed_dim": 8, + "ddconfig": { + "z_channels": 8, + "resolution": 256, + "in_channels": 1, + "out_ch": 1, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + }, + }, + }, + "cond_stage_config": { + "crossattn_audiomae_generated": { + "target": "audioldm2.latent_diffusion.modules.encoders.modules.SequenceGenAudioMAECond", + "params": { + "sequence_gen_length": 8, + "sequence_input_embed_dim": [512, 1024], + }, + } + }, + "vocoder_config": { + "target": "audioldm2.first_stage_model.vocoder", + "params": { + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "num_mels": 64, + "sampling_rate": 16000, + }, + }, + }, + }, +} + + +def load_pipeline_from_original_AudioLDM2_ckpt( + checkpoint_path: str, + original_config_file: str = None, + image_size: int = 1024, + prediction_type: str = None, + extract_ema: bool = False, + scheduler_type: str = "ddim", + cross_attention_dim: Union[List, List[List]] = None, + transformer_layers_per_block: int = None, + device: str = None, + from_safetensors: bool = False, +) -> AudioLDM2Pipeline: + """ + Load an AudioLDM2 pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path (`str`): Path to `.ckpt` file. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + set to the AudioLDM2 base config. + image_size (`int`, *optional*, defaults to 1024): + The image size that the model was trained on. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. If `None`, will be automatically + inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`. + scheduler_type (`str`, *optional*, defaults to 'ddim'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + cross_attention_dim (`list`, *optional*, defaults to `None`): + The dimension of the cross-attention layers. If `None`, the cross-attention dimension will be + automatically inferred. Set to `[768, 1024]` for the base model, or `[768, 1024, None]` for the large model. + transformer_layers_per_block (`int`, *optional*, defaults to `None`): + The number of transformer layers in each transformer block. If `None`, number of layers will be " + "automatically inferred. Set to `1` for the base model, or `2` for the large model. + extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to + `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + device (`str`, *optional*, defaults to `None`): + The device to use. Pass `None` to determine automatically. + from_safetensors (`str`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. + return: An AudioLDM2Pipeline object representing the passed-in `.ckpt`/`.safetensors` file. + """ + + if from_safetensors: + if not is_safetensors_available(): + raise ValueError(BACKENDS_MAPPING["safetensors"][1]) + + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + if original_config_file is None: + original_config = DEFAULT_CONFIG + else: + original_config = yaml.safe_load(original_config_file) + + if image_size is not None: + original_config["model"]["params"]["unet_config"]["params"]["image_size"] = image_size + + if cross_attention_dim is not None: + original_config["model"]["params"]["unet_config"]["params"]["context_dim"] = cross_attention_dim + + if transformer_layers_per_block is not None: + original_config["model"]["params"]["unet_config"]["params"]["transformer_depth"] = transformer_layers_per_block + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + prediction_type = "v_prediction" + else: + if prediction_type is None: + prediction_type = "epsilon" + + num_train_timesteps = original_config["model"]["params"]["timesteps"] + beta_start = original_config["model"]["params"]["linear_start"] + beta_end = original_config["model"]["params"]["linear_end"] + + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + # Convert the UNet2DModel + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet = AudioLDM2UNet2DConditionModel(**unet_config) + + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model + vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + + # Convert the joint audio-text encoding model + clap_config = ClapConfig.from_pretrained("laion/clap-htsat-unfused") + clap_config.audio_config.update( + { + "patch_embeds_hidden_size": 128, + "hidden_size": 1024, + "depths": [2, 2, 12, 2], + } + ) + # AudioLDM2 uses the same tokenizer and feature extractor as the original CLAP model + clap_tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused") + clap_feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused") + + converted_clap_model = convert_open_clap_checkpoint(checkpoint) + clap_model = ClapModel(clap_config) + + missing_keys, unexpected_keys = clap_model.load_state_dict(converted_clap_model, strict=False) + # we expect not to have token_type_ids in our original state dict so let's ignore them + missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS)) + + if len(unexpected_keys) > 0: + raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}") + + if len(missing_keys) > 0: + raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}") + + # Convert the vocoder model + vocoder_config = create_transformers_vocoder_config(original_config) + vocoder_config = SpeechT5HifiGanConfig(**vocoder_config) + converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config) + + vocoder = SpeechT5HifiGan(vocoder_config) + vocoder.load_state_dict(converted_vocoder_checkpoint) + + # Convert the Flan-T5 encoder model: AudioLDM2 uses the same configuration and tokenizer as the original Flan-T5 large model + t5_config = T5Config.from_pretrained("google/flan-t5-large") + converted_t5_checkpoint = extract_sub_model(checkpoint, key_prefix="cond_stage_models.1.model.") + + t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large") + # hard-coded in the original implementation (i.e. not retrievable from the config) + t5_tokenizer.model_max_length = 128 + t5_model = T5EncoderModel(t5_config) + t5_model.load_state_dict(converted_t5_checkpoint) + + # Convert the GPT2 encoder model: AudioLDM2 uses the same configuration as the original GPT2 base model + gpt2_config = GPT2Config.from_pretrained("gpt2") + gpt2_model = GPT2Model(gpt2_config) + gpt2_model.config.max_new_tokens = original_config["model"]["params"]["cond_stage_config"][ + "crossattn_audiomae_generated" + ]["params"]["sequence_gen_length"] + + converted_gpt2_checkpoint = extract_sub_model(checkpoint, key_prefix="cond_stage_models.0.model.") + gpt2_model.load_state_dict(converted_gpt2_checkpoint) + + # Convert the extra embedding / projection layers + projection_model = AudioLDM2ProjectionModel(clap_config.projection_dim, t5_config.d_model, gpt2_config.n_embd) + + converted_projection_checkpoint = convert_projection_checkpoint(checkpoint) + projection_model.load_state_dict(converted_projection_checkpoint) + + # Instantiate the diffusers pipeline + pipe = AudioLDM2Pipeline( + vae=vae, + text_encoder=clap_model, + text_encoder_2=t5_model, + projection_model=projection_model, + language_model=gpt2_model, + tokenizer=clap_tokenizer, + tokenizer_2=t5_tokenizer, + feature_extractor=clap_feature_extractor, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + + return pipe + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--original_config_file", + default=None, + type=str, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--cross_attention_dim", + default=None, + type=int, + nargs="+", + help="The dimension of the cross-attention layers. If `None`, the cross-attention dimension will be " + "automatically inferred. Set to `768+1024` for the base model, or `768+1024+640` for the large model", + ) + parser.add_argument( + "--transformer_layers_per_block", + default=None, + type=int, + help="The number of transformer layers in each transformer block. If `None`, number of layers will be " + "automatically inferred. Set to `1` for the base model, or `2` for the large model.", + ) + parser.add_argument( + "--scheduler_type", + default="ddim", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", + ) + parser.add_argument( + "--image_size", + default=1048, + type=int, + help="The image size that the model was trained on.", + ) + parser.add_argument( + "--prediction_type", + default=None, + type=str, + help=("The prediction type that the model was trained on."), + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--from_safetensors", + action="store_true", + help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + args = parser.parse_args() + + pipe = load_pipeline_from_original_AudioLDM2_ckpt( + checkpoint_path=args.checkpoint_path, + original_config_file=args.original_config_file, + image_size=args.image_size, + prediction_type=args.prediction_type, + extract_ema=args.extract_ema, + scheduler_type=args.scheduler_type, + cross_attention_dim=args.cross_attention_dim, + transformer_layers_per_block=args.transformer_layers_per_block, + from_safetensors=args.from_safetensors, + device=args.device, + ) + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/diffusers/scripts/convert_original_audioldm_to_diffusers.py b/diffusers/scripts/convert_original_audioldm_to_diffusers.py new file mode 100755 index 0000000..797d198 --- /dev/null +++ b/diffusers/scripts/convert_original_audioldm_to_diffusers.py @@ -0,0 +1,1042 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the AudioLDM checkpoints.""" + +import argparse +import re + +import torch +import yaml +from transformers import ( + AutoTokenizer, + ClapTextConfig, + ClapTextModelWithProjection, + SpeechT5HifiGan, + SpeechT5HifiGanConfig, +) + +from diffusers import ( + AudioLDMPipeline, + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, +) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths +def renew_attention_paths(old_list): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int): + """ + Creates a UNet config for diffusers based on the config of the original AudioLDM model. + """ + unet_params = original_config["model"]["params"]["unet_config"]["params"] + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + cross_attention_dim = ( + unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels + ) + + class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None + projection_class_embeddings_input_dim = ( + unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None + ) + class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "out_channels": unet_params["out_channels"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": cross_attention_dim, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "class_embeddings_concat": class_embeddings_concat, + } + + return config + + +# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config +def create_vae_diffusers_config(original_config, checkpoint, image_size: int): + """ + Creates a VAE config for diffusers based on the config of the original AudioLDM model. Compared to the original + Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE. + """ + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215 + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + "scaling_factor": float(scaling_factor), + } + return config + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config["model"]["params"]["timesteps"], + beta_start=original_config["model"]["params"]["linear_start"], + beta_end=original_config["model"]["params"]["linear_end"], + beta_schedule="scaled_linear", + ) + return schedular + + +# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_unet_checkpoint +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. Compared to the original Stable Diffusion + conversion, this function additionally converts the learnt film embedding linear layer. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["class_embedding.weight"] = unet_state_dict["film_emb.weight"] + new_checkpoint["class_embedding.bias"] = unet_state_dict["film_emb.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +CLAP_KEYS_TO_MODIFY_MAPPING = { + "text_branch": "text_model", + "attn": "attention.self", + "self.proj": "output.dense", + "attention.self_mask": "attn_mask", + "mlp.fc1": "intermediate.dense", + "mlp.fc2": "output.dense", + "norm1": "layernorm_before", + "norm2": "layernorm_after", + "bn0": "batch_norm", +} + +CLAP_KEYS_TO_IGNORE = ["text_transform"] + +CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"] + + +def convert_open_clap_checkpoint(checkpoint): + """ + Takes a state dict and returns a converted CLAP checkpoint. + """ + # extract state dict for CLAP text embedding model, discarding the audio component + model_state_dict = {} + model_key = "cond_stage_model.model.text_" + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(model_key): + model_state_dict[key.replace(model_key, "text_")] = checkpoint.get(key) + + new_checkpoint = {} + + sequential_layers_pattern = r".*sequential.(\d+).*" + text_projection_pattern = r".*_projection.(\d+).*" + + for key, value in model_state_dict.items(): + # check if key should be ignored in mapping + if key.split(".")[0] in CLAP_KEYS_TO_IGNORE: + continue + + # check if any key needs to be modified + for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + if re.match(sequential_layers_pattern, key): + # replace sequential layers with list + sequential_layer = re.match(sequential_layers_pattern, key).group(1) + + key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") + elif re.match(text_projection_pattern, key): + projecton_layer = int(re.match(text_projection_pattern, key).group(1)) + + # Because in CLAP they use `nn.Sequential`... + transformers_projection_layer = 1 if projecton_layer == 0 else 2 + + key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.") + + if "audio" and "qkv" in key: + # split qkv into query key and value + mixed_qkv = value + qkv_dim = mixed_qkv.size(0) // 3 + + query_layer = mixed_qkv[:qkv_dim] + key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[qkv_dim * 2 :] + + new_checkpoint[key.replace("qkv", "query")] = query_layer + new_checkpoint[key.replace("qkv", "key")] = key_layer + new_checkpoint[key.replace("qkv", "value")] = value_layer + else: + new_checkpoint[key] = value + + return new_checkpoint + + +def create_transformers_vocoder_config(original_config): + """ + Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model. + """ + vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"] + + config = { + "model_in_dim": vocoder_params["num_mels"], + "sampling_rate": vocoder_params["sampling_rate"], + "upsample_initial_channel": vocoder_params["upsample_initial_channel"], + "upsample_rates": list(vocoder_params["upsample_rates"]), + "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]), + "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]), + "resblock_dilation_sizes": [ + list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"] + ], + "normalize_before": False, + } + + return config + + +def convert_hifigan_checkpoint(checkpoint, config): + """ + Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint. + """ + # extract state dict for vocoder + vocoder_state_dict = {} + vocoder_key = "first_stage_model.vocoder." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vocoder_key): + vocoder_state_dict[key.replace(vocoder_key, "")] = checkpoint.get(key) + + # fix upsampler keys, everything else is correct already + for i in range(len(config.upsample_rates)): + vocoder_state_dict[f"upsampler.{i}.weight"] = vocoder_state_dict.pop(f"ups.{i}.weight") + vocoder_state_dict[f"upsampler.{i}.bias"] = vocoder_state_dict.pop(f"ups.{i}.bias") + + if not config.normalize_before: + # if we don't set normalize_before then these variables are unused, so we set them to their initialised values + vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim) + vocoder_state_dict["scale"] = torch.ones(config.model_in_dim) + + return vocoder_state_dict + + +# Adapted from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/84a0384742a22bd80c44e903e241f0623e874f1d/audioldm/utils.py#L72-L73 +DEFAULT_CONFIG = { + "model": { + "params": { + "linear_start": 0.0015, + "linear_end": 0.0195, + "timesteps": 1000, + "channels": 8, + "scale_by_std": True, + "unet_config": { + "target": "audioldm.latent_diffusion.openaimodel.UNetModel", + "params": { + "extra_film_condition_dim": 512, + "extra_film_use_concat": True, + "in_channels": 8, + "out_channels": 8, + "model_channels": 128, + "attention_resolutions": [8, 4, 2], + "num_res_blocks": 2, + "channel_mult": [1, 2, 3, 5], + "num_head_channels": 32, + }, + }, + "first_stage_config": { + "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL", + "params": { + "embed_dim": 8, + "ddconfig": { + "z_channels": 8, + "resolution": 256, + "in_channels": 1, + "out_ch": 1, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + }, + }, + }, + "vocoder_config": { + "target": "audioldm.first_stage_model.vocoder", + "params": { + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "num_mels": 64, + "sampling_rate": 16000, + }, + }, + }, + }, +} + + +def load_pipeline_from_original_audioldm_ckpt( + checkpoint_path: str, + original_config_file: str = None, + image_size: int = 512, + prediction_type: str = None, + extract_ema: bool = False, + scheduler_type: str = "ddim", + num_in_channels: int = None, + model_channels: int = None, + num_head_channels: int = None, + device: str = None, + from_safetensors: bool = False, +) -> AudioLDMPipeline: + """ + Load an AudioLDM pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path (`str`): Path to `.ckpt` file. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + set to the audioldm-s-full-v2 config. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. If `None`, will be automatically + inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`. + num_in_channels (`int`, *optional*, defaults to None): + The number of UNet input channels. If `None`, it will be automatically inferred from the config. + model_channels (`int`, *optional*, defaults to None): + The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override + to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large. + num_head_channels (`int`, *optional*, defaults to None): + The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override + to 32 for the small and medium checkpoints, and 64 for the large. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to + `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + device (`str`, *optional*, defaults to `None`): + The device to use. Pass `None` to determine automatically. + from_safetensors (`str`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. + return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + """ + + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + if original_config_file is None: + original_config = DEFAULT_CONFIG + else: + original_config = yaml.safe_load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if model_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels + + if num_head_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + prediction_type = "v_prediction" + else: + if prediction_type is None: + prediction_type = "epsilon" + + if image_size is None: + image_size = 512 + + num_train_timesteps = original_config["model"]["params"]["timesteps"] + beta_start = original_config["model"]["params"]["linear_start"] + beta_end = original_config["model"]["params"]["linear_end"] + + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + # Convert the UNet2DModel + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet = UNet2DConditionModel(**unet_config) + + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model + vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + + # Convert the text model + # AudioLDM uses the same configuration and tokenizer as the original CLAP model + config = ClapTextConfig.from_pretrained("laion/clap-htsat-unfused") + tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused") + + converted_text_model = convert_open_clap_checkpoint(checkpoint) + text_model = ClapTextModelWithProjection(config) + + missing_keys, unexpected_keys = text_model.load_state_dict(converted_text_model, strict=False) + # we expect not to have token_type_ids in our original state dict so let's ignore them + missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS)) + + if len(unexpected_keys) > 0: + raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}") + + if len(missing_keys) > 0: + raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}") + + # Convert the vocoder model + vocoder_config = create_transformers_vocoder_config(original_config) + vocoder_config = SpeechT5HifiGanConfig(**vocoder_config) + converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config) + + vocoder = SpeechT5HifiGan(vocoder_config) + vocoder.load_state_dict(converted_vocoder_checkpoint) + + # Instantiate the diffusers pipeline + pipe = AudioLDMPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + + return pipe + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--original_config_file", + default=None, + type=str, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) + parser.add_argument( + "--model_channels", + default=None, + type=int, + help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override" + " to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.", + ) + parser.add_argument( + "--num_head_channels", + default=None, + type=int, + help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override" + " to 32 for the small and medium checkpoints, and 64 for the large.", + ) + parser.add_argument( + "--scheduler_type", + default="ddim", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", + ) + parser.add_argument( + "--image_size", + default=None, + type=int, + help=("The image size that the model was trained on."), + ) + parser.add_argument( + "--prediction_type", + default=None, + type=str, + help=("The prediction type that the model was trained on."), + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--from_safetensors", + action="store_true", + help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + args = parser.parse_args() + + pipe = load_pipeline_from_original_audioldm_ckpt( + checkpoint_path=args.checkpoint_path, + original_config_file=args.original_config_file, + image_size=args.image_size, + prediction_type=args.prediction_type, + extract_ema=args.extract_ema, + scheduler_type=args.scheduler_type, + num_in_channels=args.num_in_channels, + model_channels=args.model_channels, + num_head_channels=args.num_head_channels, + from_safetensors=args.from_safetensors, + device=args.device, + ) + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/diffusers/scripts/convert_original_controlnet_to_diffusers.py b/diffusers/scripts/convert_original_controlnet_to_diffusers.py new file mode 100755 index 0000000..92aad4f --- /dev/null +++ b/diffusers/scripts/convert_original_controlnet_to_diffusers.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for stable diffusion checkpoints which _only_ contain a controlnet.""" + +import argparse + +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--original_config_file", + type=str, + required=True, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) + parser.add_argument( + "--image_size", + default=512, + type=int, + help=( + "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2" + " Base. Use 768 for Stable Diffusion v2." + ), + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--upcast_attention", + action="store_true", + help=( + "Whether the attention computation should always be upcasted. This is necessary when running stable" + " diffusion 2.1." + ), + ) + parser.add_argument( + "--from_safetensors", + action="store_true", + help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + # small workaround to get argparser to parse a boolean input as either true _or_ false + def parse_bool(string): + if string == "True": + return True + elif string == "False": + return False + else: + raise ValueError(f"could not parse string as bool {string}") + + parser.add_argument( + "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool + ) + + parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int) + + args = parser.parse_args() + + controlnet = download_controlnet_from_original_ckpt( + checkpoint_path=args.checkpoint_path, + original_config_file=args.original_config_file, + image_size=args.image_size, + extract_ema=args.extract_ema, + num_in_channels=args.num_in_channels, + upcast_attention=args.upcast_attention, + from_safetensors=args.from_safetensors, + device=args.device, + use_linear_projection=args.use_linear_projection, + cross_attention_dim=args.cross_attention_dim, + ) + + controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/diffusers/scripts/convert_original_musicldm_to_diffusers.py b/diffusers/scripts/convert_original_musicldm_to_diffusers.py new file mode 100755 index 0000000..6db9dbd --- /dev/null +++ b/diffusers/scripts/convert_original_musicldm_to_diffusers.py @@ -0,0 +1,1056 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the MusicLDM checkpoints.""" + +import argparse +import re + +import torch +import yaml +from transformers import ( + AutoFeatureExtractor, + AutoTokenizer, + ClapConfig, + ClapModel, + SpeechT5HifiGan, + SpeechT5HifiGanConfig, +) + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + MusicLDMPipeline, + PNDMScheduler, + UNet2DConditionModel, +) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths +def renew_attention_paths(old_list): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"] + proj_key = "to_out.0.weight" + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys or ".".join(key.split(".")[-3:]) == proj_key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key].squeeze() + + +def create_unet_diffusers_config(original_config, image_size: int): + """ + Creates a UNet config for diffusers based on the config of the original MusicLDM model. + """ + unet_params = original_config["model"]["params"]["unet_config"]["params"] + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + cross_attention_dim = ( + unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels + ) + + class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None + projection_class_embeddings_input_dim = ( + unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None + ) + class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "out_channels": unet_params["out_channels"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": cross_attention_dim, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "class_embeddings_concat": class_embeddings_concat, + } + + return config + + +# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config +def create_vae_diffusers_config(original_config, checkpoint, image_size: int): + """ + Creates a VAE config for diffusers based on the config of the original MusicLDM model. Compared to the original + Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE. + """ + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215 + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + "scaling_factor": float(scaling_factor), + } + return config + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config["model"]["params"]["timesteps"], + beta_start=original_config["model"]["params"]["linear_start"], + beta_end=original_config["model"]["params"]["linear_end"], + beta_schedule="scaled_linear", + ) + return schedular + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. Compared to the original Stable Diffusion + conversion, this function additionally converts the learnt film embedding linear layer. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["class_embedding.weight"] = unet_state_dict["film_emb.weight"] + new_checkpoint["class_embedding.bias"] = unet_state_dict["film_emb.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +CLAP_KEYS_TO_MODIFY_MAPPING = { + "text_branch": "text_model", + "audio_branch": "audio_model.audio_encoder", + "attn": "attention.self", + "self.proj": "output.dense", + "attention.self_mask": "attn_mask", + "mlp.fc1": "intermediate.dense", + "mlp.fc2": "output.dense", + "norm1": "layernorm_before", + "norm2": "layernorm_after", + "bn0": "batch_norm", +} + +CLAP_KEYS_TO_IGNORE = [ + "text_transform", + "audio_transform", + "stft", + "logmel_extractor", + "tscam_conv", + "head", + "attn_mask", +] + +CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"] + + +def convert_open_clap_checkpoint(checkpoint): + """ + Takes a state dict and returns a converted CLAP checkpoint. + """ + # extract state dict for CLAP text embedding model, discarding the audio component + model_state_dict = {} + model_key = "cond_stage_model.model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(model_key): + model_state_dict[key.replace(model_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + sequential_layers_pattern = r".*sequential.(\d+).*" + text_projection_pattern = r".*_projection.(\d+).*" + + for key, value in model_state_dict.items(): + # check if key should be ignored in mapping - if so map it to a key name that we'll filter out at the end + for key_to_ignore in CLAP_KEYS_TO_IGNORE: + if key_to_ignore in key: + key = "spectrogram" + + # check if any key needs to be modified + for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + if re.match(sequential_layers_pattern, key): + # replace sequential layers with list + sequential_layer = re.match(sequential_layers_pattern, key).group(1) + + key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") + elif re.match(text_projection_pattern, key): + projecton_layer = int(re.match(text_projection_pattern, key).group(1)) + + # Because in CLAP they use `nn.Sequential`... + transformers_projection_layer = 1 if projecton_layer == 0 else 2 + + key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.") + + if "audio" and "qkv" in key: + # split qkv into query key and value + mixed_qkv = value + qkv_dim = mixed_qkv.size(0) // 3 + + query_layer = mixed_qkv[:qkv_dim] + key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[qkv_dim * 2 :] + + new_checkpoint[key.replace("qkv", "query")] = query_layer + new_checkpoint[key.replace("qkv", "key")] = key_layer + new_checkpoint[key.replace("qkv", "value")] = value_layer + elif key != "spectrogram": + new_checkpoint[key] = value + + return new_checkpoint + + +def create_transformers_vocoder_config(original_config): + """ + Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model. + """ + vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"] + + config = { + "model_in_dim": vocoder_params["num_mels"], + "sampling_rate": vocoder_params["sampling_rate"], + "upsample_initial_channel": vocoder_params["upsample_initial_channel"], + "upsample_rates": list(vocoder_params["upsample_rates"]), + "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]), + "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]), + "resblock_dilation_sizes": [ + list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"] + ], + "normalize_before": False, + } + + return config + + +def convert_hifigan_checkpoint(checkpoint, config): + """ + Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint. + """ + # extract state dict for vocoder + vocoder_state_dict = {} + vocoder_key = "first_stage_model.vocoder." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vocoder_key): + vocoder_state_dict[key.replace(vocoder_key, "")] = checkpoint.get(key) + + # fix upsampler keys, everything else is correct already + for i in range(len(config.upsample_rates)): + vocoder_state_dict[f"upsampler.{i}.weight"] = vocoder_state_dict.pop(f"ups.{i}.weight") + vocoder_state_dict[f"upsampler.{i}.bias"] = vocoder_state_dict.pop(f"ups.{i}.bias") + + if not config.normalize_before: + # if we don't set normalize_before then these variables are unused, so we set them to their initialised values + vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim) + vocoder_state_dict["scale"] = torch.ones(config.model_in_dim) + + return vocoder_state_dict + + +# Adapted from https://huggingface.co/spaces/haoheliu/MusicLDM-text-to-audio-generation/blob/84a0384742a22bd80c44e903e241f0623e874f1d/MusicLDM/utils.py#L72-L73 +DEFAULT_CONFIG = { + "model": { + "params": { + "linear_start": 0.0015, + "linear_end": 0.0195, + "timesteps": 1000, + "channels": 8, + "scale_by_std": True, + "unet_config": { + "target": "MusicLDM.latent_diffusion.openaimodel.UNetModel", + "params": { + "extra_film_condition_dim": 512, + "extra_film_use_concat": True, + "in_channels": 8, + "out_channels": 8, + "model_channels": 128, + "attention_resolutions": [8, 4, 2], + "num_res_blocks": 2, + "channel_mult": [1, 2, 3, 5], + "num_head_channels": 32, + }, + }, + "first_stage_config": { + "target": "MusicLDM.variational_autoencoder.autoencoder.AutoencoderKL", + "params": { + "embed_dim": 8, + "ddconfig": { + "z_channels": 8, + "resolution": 256, + "in_channels": 1, + "out_ch": 1, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + }, + }, + }, + "vocoder_config": { + "target": "MusicLDM.first_stage_model.vocoder", + "params": { + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "num_mels": 64, + "sampling_rate": 16000, + }, + }, + }, + }, +} + + +def load_pipeline_from_original_MusicLDM_ckpt( + checkpoint_path: str, + original_config_file: str = None, + image_size: int = 1024, + prediction_type: str = None, + extract_ema: bool = False, + scheduler_type: str = "ddim", + num_in_channels: int = None, + model_channels: int = None, + num_head_channels: int = None, + device: str = None, + from_safetensors: bool = False, +) -> MusicLDMPipeline: + """ + Load an MusicLDM pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path (`str`): Path to `.ckpt` file. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + set to the MusicLDM-s-full-v2 config. + image_size (`int`, *optional*, defaults to 1024): + The image size that the model was trained on. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. If `None`, will be automatically + inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`. + num_in_channels (`int`, *optional*, defaults to None): + The number of UNet input channels. If `None`, it will be automatically inferred from the config. + model_channels (`int`, *optional*, defaults to None): + The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override + to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large. + num_head_channels (`int`, *optional*, defaults to None): + The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override + to 32 for the small and medium checkpoints, and 64 for the large. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to + `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + device (`str`, *optional*, defaults to `None`): + The device to use. Pass `None` to determine automatically. + from_safetensors (`str`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. + return: An MusicLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + """ + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + if original_config_file is None: + original_config = DEFAULT_CONFIG + else: + original_config = yaml.safe_load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if model_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels + + if num_head_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + prediction_type = "v_prediction" + else: + if prediction_type is None: + prediction_type = "epsilon" + + if image_size is None: + image_size = 512 + + num_train_timesteps = original_config["model"]["params"]["timesteps"] + beta_start = original_config["model"]["params"]["linear_start"] + beta_end = original_config["model"]["params"]["linear_end"] + + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + # Convert the UNet2DModel + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet = UNet2DConditionModel(**unet_config) + + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model + vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + + # Convert the text model + # MusicLDM uses the same tokenizer as the original CLAP model, but a slightly different configuration + config = ClapConfig.from_pretrained("laion/clap-htsat-unfused") + config.audio_config.update( + { + "patch_embeds_hidden_size": 128, + "hidden_size": 1024, + "depths": [2, 2, 12, 2], + } + ) + tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused") + feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused") + + converted_text_model = convert_open_clap_checkpoint(checkpoint) + text_model = ClapModel(config) + + missing_keys, unexpected_keys = text_model.load_state_dict(converted_text_model, strict=False) + # we expect not to have token_type_ids in our original state dict so let's ignore them + missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS)) + + if len(unexpected_keys) > 0: + raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}") + + if len(missing_keys) > 0: + raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}") + + # Convert the vocoder model + vocoder_config = create_transformers_vocoder_config(original_config) + vocoder_config = SpeechT5HifiGanConfig(**vocoder_config) + converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config) + + vocoder = SpeechT5HifiGan(vocoder_config) + vocoder.load_state_dict(converted_vocoder_checkpoint) + + # Instantiate the diffusers pipeline + pipe = MusicLDMPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + feature_extractor=feature_extractor, + ) + + return pipe + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--original_config_file", + default=None, + type=str, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) + parser.add_argument( + "--model_channels", + default=None, + type=int, + help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override" + " to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.", + ) + parser.add_argument( + "--num_head_channels", + default=None, + type=int, + help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override" + " to 32 for the small and medium checkpoints, and 64 for the large.", + ) + parser.add_argument( + "--scheduler_type", + default="ddim", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", + ) + parser.add_argument( + "--image_size", + default=None, + type=int, + help=("The image size that the model was trained on."), + ) + parser.add_argument( + "--prediction_type", + default=None, + type=str, + help=("The prediction type that the model was trained on."), + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--from_safetensors", + action="store_true", + help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + args = parser.parse_args() + + pipe = load_pipeline_from_original_MusicLDM_ckpt( + checkpoint_path=args.checkpoint_path, + original_config_file=args.original_config_file, + image_size=args.image_size, + prediction_type=args.prediction_type, + extract_ema=args.extract_ema, + scheduler_type=args.scheduler_type, + num_in_channels=args.num_in_channels, + model_channels=args.model_channels, + num_head_channels=args.num_head_channels, + from_safetensors=args.from_safetensors, + device=args.device, + ) + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py b/diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py new file mode 100755 index 0000000..7e7925b --- /dev/null +++ b/diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the LDM checkpoints.""" + +import argparse +import importlib + +import torch + +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml + parser.add_argument( + "--original_config_file", + default=None, + type=str, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--config_files", + default=None, + type=str, + help="The YAML config file corresponding to the architecture.", + ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) + parser.add_argument( + "--scheduler_type", + default="pndm", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", + ) + parser.add_argument( + "--pipeline_type", + default=None, + type=str, + help=( + "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'" + ". If `None` pipeline will be automatically inferred." + ), + ) + parser.add_argument( + "--image_size", + default=None, + type=int, + help=( + "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2" + " Base. Use 768 for Stable Diffusion v2." + ), + ) + parser.add_argument( + "--prediction_type", + default=None, + type=str, + help=( + "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable" + " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2." + ), + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--upcast_attention", + action="store_true", + help=( + "Whether the attention computation should always be upcasted. This is necessary when running stable" + " diffusion 2.1." + ), + ) + parser.add_argument( + "--from_safetensors", + action="store_true", + help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + parser.add_argument( + "--stable_unclip", + type=str, + default=None, + required=False, + help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.", + ) + parser.add_argument( + "--stable_unclip_prior", + type=str, + default=None, + required=False, + help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.", + ) + parser.add_argument( + "--clip_stats_path", + type=str, + help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.", + required=False, + ) + parser.add_argument( + "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." + ) + parser.add_argument("--half", action="store_true", help="Save weights in half precision.") + parser.add_argument( + "--vae_path", + type=str, + default=None, + required=False, + help="Set to a path, hub id to an already converted vae to not convert it again.", + ) + parser.add_argument( + "--pipeline_class_name", + type=str, + default=None, + required=False, + help="Specify the pipeline class name", + ) + + args = parser.parse_args() + + if args.pipeline_class_name is not None: + library = importlib.import_module("diffusers") + class_obj = getattr(library, args.pipeline_class_name) + pipeline_class = class_obj + else: + pipeline_class = None + + pipe = download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict=args.checkpoint_path, + original_config_file=args.original_config_file, + config_files=args.config_files, + image_size=args.image_size, + prediction_type=args.prediction_type, + model_type=args.pipeline_type, + extract_ema=args.extract_ema, + scheduler_type=args.scheduler_type, + num_in_channels=args.num_in_channels, + upcast_attention=args.upcast_attention, + from_safetensors=args.from_safetensors, + device=args.device, + stable_unclip=args.stable_unclip, + stable_unclip_prior=args.stable_unclip_prior, + clip_stats_path=args.clip_stats_path, + controlnet=args.controlnet, + vae_path=args.vae_path, + pipeline_class=pipeline_class, + ) + + if args.half: + pipe.to(dtype=torch.float16) + + if args.controlnet: + # only save the controlnet model + pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) + else: + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/diffusers/scripts/convert_original_t2i_adapter.py b/diffusers/scripts/convert_original_t2i_adapter.py new file mode 100755 index 0000000..95c8817 --- /dev/null +++ b/diffusers/scripts/convert_original_t2i_adapter.py @@ -0,0 +1,250 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Conversion script for the T2I-Adapter checkpoints. +""" + +import argparse + +import torch + +from diffusers import T2IAdapter + + +def convert_adapter(src_state, in_channels): + original_body_length = max([int(x.split(".")[1]) for x in src_state.keys() if "body." in x]) + 1 + + assert original_body_length == 8 + + # (0, 1) -> channels 1 + assert src_state["body.0.block1.weight"].shape == (320, 320, 3, 3) + + # (2, 3) -> channels 2 + assert src_state["body.2.in_conv.weight"].shape == (640, 320, 1, 1) + + # (4, 5) -> channels 3 + assert src_state["body.4.in_conv.weight"].shape == (1280, 640, 1, 1) + + # (6, 7) -> channels 4 + assert src_state["body.6.block1.weight"].shape == (1280, 1280, 3, 3) + + res_state = { + "adapter.conv_in.weight": src_state.pop("conv_in.weight"), + "adapter.conv_in.bias": src_state.pop("conv_in.bias"), + # 0.resnets.0 + "adapter.body.0.resnets.0.block1.weight": src_state.pop("body.0.block1.weight"), + "adapter.body.0.resnets.0.block1.bias": src_state.pop("body.0.block1.bias"), + "adapter.body.0.resnets.0.block2.weight": src_state.pop("body.0.block2.weight"), + "adapter.body.0.resnets.0.block2.bias": src_state.pop("body.0.block2.bias"), + # 0.resnets.1 + "adapter.body.0.resnets.1.block1.weight": src_state.pop("body.1.block1.weight"), + "adapter.body.0.resnets.1.block1.bias": src_state.pop("body.1.block1.bias"), + "adapter.body.0.resnets.1.block2.weight": src_state.pop("body.1.block2.weight"), + "adapter.body.0.resnets.1.block2.bias": src_state.pop("body.1.block2.bias"), + # 1 + "adapter.body.1.in_conv.weight": src_state.pop("body.2.in_conv.weight"), + "adapter.body.1.in_conv.bias": src_state.pop("body.2.in_conv.bias"), + # 1.resnets.0 + "adapter.body.1.resnets.0.block1.weight": src_state.pop("body.2.block1.weight"), + "adapter.body.1.resnets.0.block1.bias": src_state.pop("body.2.block1.bias"), + "adapter.body.1.resnets.0.block2.weight": src_state.pop("body.2.block2.weight"), + "adapter.body.1.resnets.0.block2.bias": src_state.pop("body.2.block2.bias"), + # 1.resnets.1 + "adapter.body.1.resnets.1.block1.weight": src_state.pop("body.3.block1.weight"), + "adapter.body.1.resnets.1.block1.bias": src_state.pop("body.3.block1.bias"), + "adapter.body.1.resnets.1.block2.weight": src_state.pop("body.3.block2.weight"), + "adapter.body.1.resnets.1.block2.bias": src_state.pop("body.3.block2.bias"), + # 2 + "adapter.body.2.in_conv.weight": src_state.pop("body.4.in_conv.weight"), + "adapter.body.2.in_conv.bias": src_state.pop("body.4.in_conv.bias"), + # 2.resnets.0 + "adapter.body.2.resnets.0.block1.weight": src_state.pop("body.4.block1.weight"), + "adapter.body.2.resnets.0.block1.bias": src_state.pop("body.4.block1.bias"), + "adapter.body.2.resnets.0.block2.weight": src_state.pop("body.4.block2.weight"), + "adapter.body.2.resnets.0.block2.bias": src_state.pop("body.4.block2.bias"), + # 2.resnets.1 + "adapter.body.2.resnets.1.block1.weight": src_state.pop("body.5.block1.weight"), + "adapter.body.2.resnets.1.block1.bias": src_state.pop("body.5.block1.bias"), + "adapter.body.2.resnets.1.block2.weight": src_state.pop("body.5.block2.weight"), + "adapter.body.2.resnets.1.block2.bias": src_state.pop("body.5.block2.bias"), + # 3.resnets.0 + "adapter.body.3.resnets.0.block1.weight": src_state.pop("body.6.block1.weight"), + "adapter.body.3.resnets.0.block1.bias": src_state.pop("body.6.block1.bias"), + "adapter.body.3.resnets.0.block2.weight": src_state.pop("body.6.block2.weight"), + "adapter.body.3.resnets.0.block2.bias": src_state.pop("body.6.block2.bias"), + # 3.resnets.1 + "adapter.body.3.resnets.1.block1.weight": src_state.pop("body.7.block1.weight"), + "adapter.body.3.resnets.1.block1.bias": src_state.pop("body.7.block1.bias"), + "adapter.body.3.resnets.1.block2.weight": src_state.pop("body.7.block2.weight"), + "adapter.body.3.resnets.1.block2.bias": src_state.pop("body.7.block2.bias"), + } + + assert len(src_state) == 0 + + adapter = T2IAdapter(in_channels=in_channels, adapter_type="full_adapter") + + adapter.load_state_dict(res_state) + + return adapter + + +def convert_light_adapter(src_state): + original_body_length = max([int(x.split(".")[1]) for x in src_state.keys() if "body." in x]) + 1 + + assert original_body_length == 4 + + res_state = { + # body.0.in_conv + "adapter.body.0.in_conv.weight": src_state.pop("body.0.in_conv.weight"), + "adapter.body.0.in_conv.bias": src_state.pop("body.0.in_conv.bias"), + # body.0.resnets.0 + "adapter.body.0.resnets.0.block1.weight": src_state.pop("body.0.body.0.block1.weight"), + "adapter.body.0.resnets.0.block1.bias": src_state.pop("body.0.body.0.block1.bias"), + "adapter.body.0.resnets.0.block2.weight": src_state.pop("body.0.body.0.block2.weight"), + "adapter.body.0.resnets.0.block2.bias": src_state.pop("body.0.body.0.block2.bias"), + # body.0.resnets.1 + "adapter.body.0.resnets.1.block1.weight": src_state.pop("body.0.body.1.block1.weight"), + "adapter.body.0.resnets.1.block1.bias": src_state.pop("body.0.body.1.block1.bias"), + "adapter.body.0.resnets.1.block2.weight": src_state.pop("body.0.body.1.block2.weight"), + "adapter.body.0.resnets.1.block2.bias": src_state.pop("body.0.body.1.block2.bias"), + # body.0.resnets.2 + "adapter.body.0.resnets.2.block1.weight": src_state.pop("body.0.body.2.block1.weight"), + "adapter.body.0.resnets.2.block1.bias": src_state.pop("body.0.body.2.block1.bias"), + "adapter.body.0.resnets.2.block2.weight": src_state.pop("body.0.body.2.block2.weight"), + "adapter.body.0.resnets.2.block2.bias": src_state.pop("body.0.body.2.block2.bias"), + # body.0.resnets.3 + "adapter.body.0.resnets.3.block1.weight": src_state.pop("body.0.body.3.block1.weight"), + "adapter.body.0.resnets.3.block1.bias": src_state.pop("body.0.body.3.block1.bias"), + "adapter.body.0.resnets.3.block2.weight": src_state.pop("body.0.body.3.block2.weight"), + "adapter.body.0.resnets.3.block2.bias": src_state.pop("body.0.body.3.block2.bias"), + # body.0.out_conv + "adapter.body.0.out_conv.weight": src_state.pop("body.0.out_conv.weight"), + "adapter.body.0.out_conv.bias": src_state.pop("body.0.out_conv.bias"), + # body.1.in_conv + "adapter.body.1.in_conv.weight": src_state.pop("body.1.in_conv.weight"), + "adapter.body.1.in_conv.bias": src_state.pop("body.1.in_conv.bias"), + # body.1.resnets.0 + "adapter.body.1.resnets.0.block1.weight": src_state.pop("body.1.body.0.block1.weight"), + "adapter.body.1.resnets.0.block1.bias": src_state.pop("body.1.body.0.block1.bias"), + "adapter.body.1.resnets.0.block2.weight": src_state.pop("body.1.body.0.block2.weight"), + "adapter.body.1.resnets.0.block2.bias": src_state.pop("body.1.body.0.block2.bias"), + # body.1.resnets.1 + "adapter.body.1.resnets.1.block1.weight": src_state.pop("body.1.body.1.block1.weight"), + "adapter.body.1.resnets.1.block1.bias": src_state.pop("body.1.body.1.block1.bias"), + "adapter.body.1.resnets.1.block2.weight": src_state.pop("body.1.body.1.block2.weight"), + "adapter.body.1.resnets.1.block2.bias": src_state.pop("body.1.body.1.block2.bias"), + # body.1.body.2 + "adapter.body.1.resnets.2.block1.weight": src_state.pop("body.1.body.2.block1.weight"), + "adapter.body.1.resnets.2.block1.bias": src_state.pop("body.1.body.2.block1.bias"), + "adapter.body.1.resnets.2.block2.weight": src_state.pop("body.1.body.2.block2.weight"), + "adapter.body.1.resnets.2.block2.bias": src_state.pop("body.1.body.2.block2.bias"), + # body.1.body.3 + "adapter.body.1.resnets.3.block1.weight": src_state.pop("body.1.body.3.block1.weight"), + "adapter.body.1.resnets.3.block1.bias": src_state.pop("body.1.body.3.block1.bias"), + "adapter.body.1.resnets.3.block2.weight": src_state.pop("body.1.body.3.block2.weight"), + "adapter.body.1.resnets.3.block2.bias": src_state.pop("body.1.body.3.block2.bias"), + # body.1.out_conv + "adapter.body.1.out_conv.weight": src_state.pop("body.1.out_conv.weight"), + "adapter.body.1.out_conv.bias": src_state.pop("body.1.out_conv.bias"), + # body.2.in_conv + "adapter.body.2.in_conv.weight": src_state.pop("body.2.in_conv.weight"), + "adapter.body.2.in_conv.bias": src_state.pop("body.2.in_conv.bias"), + # body.2.body.0 + "adapter.body.2.resnets.0.block1.weight": src_state.pop("body.2.body.0.block1.weight"), + "adapter.body.2.resnets.0.block1.bias": src_state.pop("body.2.body.0.block1.bias"), + "adapter.body.2.resnets.0.block2.weight": src_state.pop("body.2.body.0.block2.weight"), + "adapter.body.2.resnets.0.block2.bias": src_state.pop("body.2.body.0.block2.bias"), + # body.2.body.1 + "adapter.body.2.resnets.1.block1.weight": src_state.pop("body.2.body.1.block1.weight"), + "adapter.body.2.resnets.1.block1.bias": src_state.pop("body.2.body.1.block1.bias"), + "adapter.body.2.resnets.1.block2.weight": src_state.pop("body.2.body.1.block2.weight"), + "adapter.body.2.resnets.1.block2.bias": src_state.pop("body.2.body.1.block2.bias"), + # body.2.body.2 + "adapter.body.2.resnets.2.block1.weight": src_state.pop("body.2.body.2.block1.weight"), + "adapter.body.2.resnets.2.block1.bias": src_state.pop("body.2.body.2.block1.bias"), + "adapter.body.2.resnets.2.block2.weight": src_state.pop("body.2.body.2.block2.weight"), + "adapter.body.2.resnets.2.block2.bias": src_state.pop("body.2.body.2.block2.bias"), + # body.2.body.3 + "adapter.body.2.resnets.3.block1.weight": src_state.pop("body.2.body.3.block1.weight"), + "adapter.body.2.resnets.3.block1.bias": src_state.pop("body.2.body.3.block1.bias"), + "adapter.body.2.resnets.3.block2.weight": src_state.pop("body.2.body.3.block2.weight"), + "adapter.body.2.resnets.3.block2.bias": src_state.pop("body.2.body.3.block2.bias"), + # body.2.out_conv + "adapter.body.2.out_conv.weight": src_state.pop("body.2.out_conv.weight"), + "adapter.body.2.out_conv.bias": src_state.pop("body.2.out_conv.bias"), + # body.3.in_conv + "adapter.body.3.in_conv.weight": src_state.pop("body.3.in_conv.weight"), + "adapter.body.3.in_conv.bias": src_state.pop("body.3.in_conv.bias"), + # body.3.body.0 + "adapter.body.3.resnets.0.block1.weight": src_state.pop("body.3.body.0.block1.weight"), + "adapter.body.3.resnets.0.block1.bias": src_state.pop("body.3.body.0.block1.bias"), + "adapter.body.3.resnets.0.block2.weight": src_state.pop("body.3.body.0.block2.weight"), + "adapter.body.3.resnets.0.block2.bias": src_state.pop("body.3.body.0.block2.bias"), + # body.3.body.1 + "adapter.body.3.resnets.1.block1.weight": src_state.pop("body.3.body.1.block1.weight"), + "adapter.body.3.resnets.1.block1.bias": src_state.pop("body.3.body.1.block1.bias"), + "adapter.body.3.resnets.1.block2.weight": src_state.pop("body.3.body.1.block2.weight"), + "adapter.body.3.resnets.1.block2.bias": src_state.pop("body.3.body.1.block2.bias"), + # body.3.body.2 + "adapter.body.3.resnets.2.block1.weight": src_state.pop("body.3.body.2.block1.weight"), + "adapter.body.3.resnets.2.block1.bias": src_state.pop("body.3.body.2.block1.bias"), + "adapter.body.3.resnets.2.block2.weight": src_state.pop("body.3.body.2.block2.weight"), + "adapter.body.3.resnets.2.block2.bias": src_state.pop("body.3.body.2.block2.bias"), + # body.3.body.3 + "adapter.body.3.resnets.3.block1.weight": src_state.pop("body.3.body.3.block1.weight"), + "adapter.body.3.resnets.3.block1.bias": src_state.pop("body.3.body.3.block1.bias"), + "adapter.body.3.resnets.3.block2.weight": src_state.pop("body.3.body.3.block2.weight"), + "adapter.body.3.resnets.3.block2.bias": src_state.pop("body.3.body.3.block2.bias"), + # body.3.out_conv + "adapter.body.3.out_conv.weight": src_state.pop("body.3.out_conv.weight"), + "adapter.body.3.out_conv.bias": src_state.pop("body.3.out_conv.bias"), + } + + assert len(src_state) == 0 + + adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280], num_res_blocks=4, adapter_type="light_adapter") + + adapter.load_state_dict(res_state) + + return adapter + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--output_path", default=None, type=str, required=True, help="Path to the store the result checkpoint." + ) + parser.add_argument( + "--is_adapter_light", + action="store_true", + help="Is checkpoint come from Adapter-Light architecture. ex: color-adapter", + ) + parser.add_argument("--in_channels", required=False, type=int, help="Input channels for non-light adapter") + + args = parser.parse_args() + src_state = torch.load(args.checkpoint_path) + + if args.is_adapter_light: + adapter = convert_light_adapter(src_state) + else: + if args.in_channels is None: + raise ValueError("set `--in_channels=`") + adapter = convert_adapter(src_state, args.in_channels) + + adapter.save_pretrained(args.output_path) diff --git a/diffusers/scripts/convert_pixart_alpha_to_diffusers.py b/diffusers/scripts/convert_pixart_alpha_to_diffusers.py new file mode 100755 index 0000000..228b479 --- /dev/null +++ b/diffusers/scripts/convert_pixart_alpha_to_diffusers.py @@ -0,0 +1,198 @@ +import argparse +import os + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel + + +ckpt_id = "PixArt-alpha/PixArt-alpha" +# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125 +interpolation_scale = {256: 0.5, 512: 1, 1024: 2} + + +def main(args): + all_state_dict = torch.load(args.orig_ckpt_path, map_location="cpu") + state_dict = all_state_dict.pop("state_dict") + converted_state_dict = {} + + # Patch embeddings. + converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight") + converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias") + + # Caption projection. + converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight") + converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias") + converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight") + converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") + + # AdaLN-single LN + converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") + converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") + + if args.image_size == 1024: + # Resolution. + converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop( + "csize_embedder.mlp.0.weight" + ) + converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop( + "csize_embedder.mlp.0.bias" + ) + converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop( + "csize_embedder.mlp.2.weight" + ) + converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop( + "csize_embedder.mlp.2.bias" + ) + # Aspect ratio. + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop( + "ar_embedder.mlp.0.weight" + ) + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop( + "ar_embedder.mlp.0.bias" + ) + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop( + "ar_embedder.mlp.2.weight" + ) + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop( + "ar_embedder.mlp.2.bias" + ) + # Shared norm. + converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight") + converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias") + + for depth in range(28): + # Transformer blocks. + converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( + f"blocks.{depth}.scale_shift_table" + ) + + # Attention is all you need 🤘 + + # Self attention. + q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0) + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias + # Projection. + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop( + f"blocks.{depth}.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop( + f"blocks.{depth}.attn.proj.bias" + ) + + # Feed-forward. + converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop( + f"blocks.{depth}.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop( + f"blocks.{depth}.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop( + f"blocks.{depth}.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop( + f"blocks.{depth}.mlp.fc2.bias" + ) + + # Cross-attention. + q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight") + q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias") + k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0) + k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0) + + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias + + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop( + f"blocks.{depth}.cross_attn.proj.bias" + ) + + # Final block. + converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias") + converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table") + + # DiT XL/2 + transformer = Transformer2DModel( + sample_size=args.image_size // 8, + num_layers=28, + attention_head_dim=72, + in_channels=4, + out_channels=8, + patch_size=2, + attention_bias=True, + num_attention_heads=16, + cross_attention_dim=1152, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + caption_channels=4096, + ) + transformer.load_state_dict(converted_state_dict, strict=True) + + assert transformer.pos_embed.pos_embed is not None + state_dict.pop("pos_embed") + state_dict.pop("y_embedder.y_embedding") + assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}" + + num_model_params = sum(p.numel() for p in transformer.parameters()) + print(f"Total number of transformer parameters: {num_model_params}") + + if args.only_transformer: + transformer.save_pretrained(os.path.join(args.dump_path, "transformer")) + else: + scheduler = DPMSolverMultistepScheduler() + + vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="sd-vae-ft-ema") + + tokenizer = T5Tokenizer.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl") + text_encoder = T5EncoderModel.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl") + + pipeline = PixArtAlphaPipeline( + tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler + ) + + pipeline.save_pretrained(args.dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--image_size", + default=1024, + type=int, + choices=[256, 512, 1024], + required=False, + help="Image size of pretrained model, either 512 or 1024.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") + parser.add_argument("--only_transformer", default=True, type=bool, required=True) + + args = parser.parse_args() + main(args) diff --git a/diffusers/scripts/convert_pixart_sigma_to_diffusers.py b/diffusers/scripts/convert_pixart_sigma_to_diffusers.py new file mode 100755 index 0000000..9572a83 --- /dev/null +++ b/diffusers/scripts/convert_pixart_sigma_to_diffusers.py @@ -0,0 +1,223 @@ +import argparse +import os + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtSigmaPipeline, Transformer2DModel + + +ckpt_id = "PixArt-alpha" +# https://github.com/PixArt-alpha/PixArt-sigma/blob/dd087141864e30ec44f12cb7448dd654be065e88/scripts/inference.py#L158 +interpolation_scale = {256: 0.5, 512: 1, 1024: 2, 2048: 4} + + +def main(args): + all_state_dict = torch.load(args.orig_ckpt_path) + state_dict = all_state_dict.pop("state_dict") + converted_state_dict = {} + + # Patch embeddings. + converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight") + converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias") + + # Caption projection. + converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight") + converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias") + converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight") + converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") + + # AdaLN-single LN + converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") + converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") + + if args.micro_condition: + # Resolution. + converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop( + "csize_embedder.mlp.0.weight" + ) + converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop( + "csize_embedder.mlp.0.bias" + ) + converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop( + "csize_embedder.mlp.2.weight" + ) + converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop( + "csize_embedder.mlp.2.bias" + ) + # Aspect ratio. + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop( + "ar_embedder.mlp.0.weight" + ) + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop( + "ar_embedder.mlp.0.bias" + ) + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop( + "ar_embedder.mlp.2.weight" + ) + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop( + "ar_embedder.mlp.2.bias" + ) + # Shared norm. + converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight") + converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias") + + for depth in range(28): + # Transformer blocks. + converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( + f"blocks.{depth}.scale_shift_table" + ) + # Attention is all you need 🤘 + + # Self attention. + q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0) + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias + # Projection. + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop( + f"blocks.{depth}.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop( + f"blocks.{depth}.attn.proj.bias" + ) + if args.qk_norm: + converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.weight"] = state_dict.pop( + f"blocks.{depth}.attn.q_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.bias"] = state_dict.pop( + f"blocks.{depth}.attn.q_norm.bias" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.weight"] = state_dict.pop( + f"blocks.{depth}.attn.k_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.bias"] = state_dict.pop( + f"blocks.{depth}.attn.k_norm.bias" + ) + + # Feed-forward. + converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop( + f"blocks.{depth}.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop( + f"blocks.{depth}.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop( + f"blocks.{depth}.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop( + f"blocks.{depth}.mlp.fc2.bias" + ) + + # Cross-attention. + q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight") + q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias") + k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0) + k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0) + + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias + + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop( + f"blocks.{depth}.cross_attn.proj.bias" + ) + + # Final block. + converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias") + converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table") + + # PixArt XL/2 + transformer = Transformer2DModel( + sample_size=args.image_size // 8, + num_layers=28, + attention_head_dim=72, + in_channels=4, + out_channels=8, + patch_size=2, + attention_bias=True, + num_attention_heads=16, + cross_attention_dim=1152, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + caption_channels=4096, + interpolation_scale=interpolation_scale[args.image_size], + use_additional_conditions=args.micro_condition, + ) + transformer.load_state_dict(converted_state_dict, strict=True) + + assert transformer.pos_embed.pos_embed is not None + try: + state_dict.pop("y_embedder.y_embedding") + state_dict.pop("pos_embed") + except Exception as e: + print(f"Skipping {str(e)}") + pass + assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}" + + num_model_params = sum(p.numel() for p in transformer.parameters()) + print(f"Total number of transformer parameters: {num_model_params}") + + if args.only_transformer: + transformer.save_pretrained(os.path.join(args.dump_path, "transformer")) + else: + # pixart-Sigma vae link: https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae + vae = AutoencoderKL.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="vae") + + scheduler = DPMSolverMultistepScheduler() + + tokenizer = T5Tokenizer.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained( + f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="text_encoder" + ) + + pipeline = PixArtSigmaPipeline( + tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler + ) + + pipeline.save_pretrained(args.dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--micro_condition", action="store_true", help="If use Micro-condition in PixArtMS structure during training." + ) + parser.add_argument("--qk_norm", action="store_true", help="If use qk norm during training.") + parser.add_argument( + "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--image_size", + default=1024, + type=int, + choices=[256, 512, 1024, 2048], + required=False, + help="Image size of pretrained model, 256, 512, 1024, or 2048.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") + parser.add_argument("--only_transformer", default=True, type=bool, required=True) + + args = parser.parse_args() + main(args) diff --git a/diffusers/scripts/convert_sd3_to_diffusers.py b/diffusers/scripts/convert_sd3_to_diffusers.py new file mode 100755 index 0000000..4f32745 --- /dev/null +++ b/diffusers/scripts/convert_sd3_to_diffusers.py @@ -0,0 +1,248 @@ +import argparse +from contextlib import nullcontext + +import safetensors.torch +import torch +from accelerate import init_empty_weights + +from diffusers import AutoencoderKL, SD3Transformer2DModel +from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint +from diffusers.models.modeling_utils import load_model_dict_into_meta +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--checkpoint_path", type=str) +parser.add_argument("--output_path", type=str) +parser.add_argument("--dtype", type=str, default="fp16") + +args = parser.parse_args() +dtype = torch.float16 if args.dtype == "fp16" else torch.float32 + + +def load_original_checkpoint(ckpt_path): + original_state_dict = safetensors.torch.load_file(ckpt_path) + keys = list(original_state_dict.keys()) + for k in keys: + if "model.diffusion_model." in k: + original_state_dict[k.replace("model.diffusion_model.", "")] = original_state_dict.pop(k) + + return original_state_dict + + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim): + converted_state_dict = {} + + # Positional and patch embeddings. + converted_state_dict["pos_embed.pos_embed"] = original_state_dict.pop("pos_embed") + converted_state_dict["pos_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight") + converted_state_dict["pos_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias") + + # Timestep embeddings. + converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop( + "t_embedder.mlp.0.bias" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop( + "t_embedder.mlp.2.bias" + ) + + # Context projections. + converted_state_dict["context_embedder.weight"] = original_state_dict.pop("context_embedder.weight") + converted_state_dict["context_embedder.bias"] = original_state_dict.pop("context_embedder.bias") + + # Pooled context projection. + converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop( + "y_embedder.mlp.0.weight" + ) + converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop( + "y_embedder.mlp.0.bias" + ) + converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop( + "y_embedder.mlp.2.weight" + ) + converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop( + "y_embedder.mlp.2.bias" + ) + + # Transformer blocks 🎸. + for i in range(num_layers): + # Q, K, V + sample_q, sample_k, sample_v = torch.chunk( + original_state_dict.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0 + ) + context_q, context_k, context_v = torch.chunk( + original_state_dict.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + original_state_dict.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + original_state_dict.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0 + ) + + converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias]) + + converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + + # output projections. + converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn.proj.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.attn.proj.bias" + ) + + # norms. + converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias" + ) + else: + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift( + original_state_dict.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"), + dim=caption_projection_dim, + ) + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift( + original_state_dict.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"), + dim=caption_projection_dim, + ) + + # ffs. + converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.mlp.fc2.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.mlp.fc2.bias" + ) + + # Final blocks. + converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim + ) + + return converted_state_dict + + +def is_vae_in_checkpoint(original_state_dict): + return ("first_stage_model.decoder.conv_in.weight" in original_state_dict) and ( + "first_stage_model.encoder.conv_in.weight" in original_state_dict + ) + + +def main(args): + original_ckpt = load_original_checkpoint(args.checkpoint_path) + num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401 + caption_projection_dim = 1536 + + converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers( + original_ckpt, num_layers, caption_projection_dim + ) + + with CTX(): + transformer = SD3Transformer2DModel( + sample_size=64, + patch_size=2, + in_channels=16, + joint_attention_dim=4096, + num_layers=num_layers, + caption_projection_dim=caption_projection_dim, + num_attention_heads=24, + pos_embed_max_size=192, + ) + if is_accelerate_available(): + load_model_dict_into_meta(transformer, converted_transformer_state_dict) + else: + transformer.load_state_dict(converted_transformer_state_dict, strict=True) + + print("Saving SD3 Transformer in Diffusers format.") + transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer") + + if is_vae_in_checkpoint(original_ckpt): + with CTX(): + vae = AutoencoderKL.from_config( + "stabilityai/stable-diffusion-xl-base-1.0", + subfolder="vae", + latent_channels=16, + use_post_quant_conv=False, + use_quant_conv=False, + scaling_factor=1.5305, + shift_factor=0.0609, + ) + converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config) + if is_accelerate_available(): + load_model_dict_into_meta(vae, converted_vae_state_dict) + else: + vae.load_state_dict(converted_vae_state_dict, strict=True) + + print("Saving SD3 Autoencoder in Diffusers format.") + vae.to(dtype).save_pretrained(f"{args.output_path}/vae") + + +if __name__ == "__main__": + main(args) diff --git a/diffusers/scripts/convert_shap_e_to_diffusers.py b/diffusers/scripts/convert_shap_e_to_diffusers.py new file mode 100755 index 0000000..b903b4e --- /dev/null +++ b/diffusers/scripts/convert_shap_e_to_diffusers.py @@ -0,0 +1,1080 @@ +import argparse +import tempfile + +import torch +from accelerate import load_checkpoint_and_dispatch + +from diffusers.models.transformers.prior_transformer import PriorTransformer +from diffusers.pipelines.shap_e import ShapERenderer + + +""" +Example - From the diffusers root directory: + +Download weights: +```sh +$ wget "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt" +``` + +Convert the model: +```sh +$ python scripts/convert_shap_e_to_diffusers.py \ + --prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \ + --prior_image_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/image_cond.pt \ + --transmitter_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt\ + --dump_path /home/yiyi_huggingface_co/model_repo/shap-e-img2img/shap_e_renderer\ + --debug renderer +``` +""" + + +# prior + +PRIOR_ORIGINAL_PREFIX = "wrapped" + +PRIOR_CONFIG = { + "num_attention_heads": 16, + "attention_head_dim": 1024 // 16, + "num_layers": 24, + "embedding_dim": 1024, + "num_embeddings": 1024, + "additional_embeddings": 0, + "time_embed_act_fn": "gelu", + "norm_in_type": "layer", + "encoder_hid_proj_type": None, + "added_emb_type": None, + "time_embed_dim": 1024 * 4, + "embedding_proj_dim": 768, + "clip_embed_dim": 1024 * 2, +} + + +def prior_model_from_original_config(): + model = PriorTransformer(**PRIOR_CONFIG) + + return model + + +def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # .time_embed.c_fc -> .time_embedding.linear_1 + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.weight"], + "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.bias"], + } + ) + + # .time_embed.c_proj -> .time_embedding.linear_2 + diffusers_checkpoint.update( + { + "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.weight"], + "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.bias"], + } + ) + + # .input_proj -> .proj_in + diffusers_checkpoint.update( + { + "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.weight"], + "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.bias"], + } + ) + + # .clip_emb -> .embedding_proj + diffusers_checkpoint.update( + { + "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.weight"], + "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.bias"], + } + ) + + # .pos_emb -> .positional_embedding + diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.pos_emb"][None, :]}) + + # .ln_pre -> .norm_in + diffusers_checkpoint.update( + { + "norm_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.weight"], + "norm_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.bias"], + } + ) + + # .backbone.resblocks. -> .transformer_blocks. + for idx in range(len(model.transformer_blocks)): + diffusers_transformer_prefix = f"transformer_blocks.{idx}" + original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.backbone.resblocks.{idx}" + + # .attn -> .attn1 + diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1" + original_attention_prefix = f"{original_transformer_prefix}.attn" + diffusers_checkpoint.update( + prior_attention_to_diffusers( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + original_attention_prefix=original_attention_prefix, + attention_head_dim=model.attention_head_dim, + ) + ) + + # .mlp -> .ff + diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff" + original_ff_prefix = f"{original_transformer_prefix}.mlp" + diffusers_checkpoint.update( + prior_ff_to_diffusers( + checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix + ) + ) + + # .ln_1 -> .norm1 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[ + f"{original_transformer_prefix}.ln_1.weight" + ], + f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"], + } + ) + + # .ln_2 -> .norm3 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[ + f"{original_transformer_prefix}.ln_2.weight" + ], + f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"], + } + ) + + # .ln_post -> .norm_out + diffusers_checkpoint.update( + { + "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.weight"], + "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.bias"], + } + ) + + # .output_proj -> .proj_to_clip_embeddings + diffusers_checkpoint.update( + { + "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.weight"], + "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.bias"], + } + ) + + return diffusers_checkpoint + + +def prior_attention_to_diffusers( + checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim +): + diffusers_checkpoint = {} + + # .c_qkv -> .{to_q, to_k, to_v} + [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions( + weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"], + bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"], + split=3, + chunk_size=attention_head_dim, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_q.weight": q_weight, + f"{diffusers_attention_prefix}.to_q.bias": q_bias, + f"{diffusers_attention_prefix}.to_k.weight": k_weight, + f"{diffusers_attention_prefix}.to_k.bias": k_bias, + f"{diffusers_attention_prefix}.to_v.weight": v_weight, + f"{diffusers_attention_prefix}.to_v.bias": v_bias, + } + ) + + # .c_proj -> .to_out.0 + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"], + } + ) + + return diffusers_checkpoint + + +def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix): + diffusers_checkpoint = { + # .c_fc -> .net.0.proj + f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"], + f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"], + # .c_proj -> .net.2 + f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"], + f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"], + } + + return diffusers_checkpoint + + +# done prior + + +# prior_image (only slightly different from prior) + + +PRIOR_IMAGE_ORIGINAL_PREFIX = "wrapped" + +# Uses default arguments +PRIOR_IMAGE_CONFIG = { + "num_attention_heads": 8, + "attention_head_dim": 1024 // 8, + "num_layers": 24, + "embedding_dim": 1024, + "num_embeddings": 1024, + "additional_embeddings": 0, + "time_embed_act_fn": "gelu", + "norm_in_type": "layer", + "embedding_proj_norm_type": "layer", + "encoder_hid_proj_type": None, + "added_emb_type": None, + "time_embed_dim": 1024 * 4, + "embedding_proj_dim": 1024, + "clip_embed_dim": 1024 * 2, +} + + +def prior_image_model_from_original_config(): + model = PriorTransformer(**PRIOR_IMAGE_CONFIG) + + return model + + +def prior_image_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # .time_embed.c_fc -> .time_embedding.linear_1 + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_fc.weight"], + "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_fc.bias"], + } + ) + + # .time_embed.c_proj -> .time_embedding.linear_2 + diffusers_checkpoint.update( + { + "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_proj.weight"], + "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_proj.bias"], + } + ) + + # .input_proj -> .proj_in + diffusers_checkpoint.update( + { + "proj_in.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.input_proj.weight"], + "proj_in.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.input_proj.bias"], + } + ) + + # .clip_embed.0 -> .embedding_proj_norm + diffusers_checkpoint.update( + { + "embedding_proj_norm.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.0.weight"], + "embedding_proj_norm.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.0.bias"], + } + ) + + # ..clip_embed.1 -> .embedding_proj + diffusers_checkpoint.update( + { + "embedding_proj.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.1.weight"], + "embedding_proj.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.1.bias"], + } + ) + + # .pos_emb -> .positional_embedding + diffusers_checkpoint.update( + {"positional_embedding": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.pos_emb"][None, :]} + ) + + # .ln_pre -> .norm_in + diffusers_checkpoint.update( + { + "norm_in.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_pre.weight"], + "norm_in.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_pre.bias"], + } + ) + + # .backbone.resblocks. -> .transformer_blocks. + for idx in range(len(model.transformer_blocks)): + diffusers_transformer_prefix = f"transformer_blocks.{idx}" + original_transformer_prefix = f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.backbone.resblocks.{idx}" + + # .attn -> .attn1 + diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1" + original_attention_prefix = f"{original_transformer_prefix}.attn" + diffusers_checkpoint.update( + prior_attention_to_diffusers( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + original_attention_prefix=original_attention_prefix, + attention_head_dim=model.attention_head_dim, + ) + ) + + # .mlp -> .ff + diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff" + original_ff_prefix = f"{original_transformer_prefix}.mlp" + diffusers_checkpoint.update( + prior_ff_to_diffusers( + checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix + ) + ) + + # .ln_1 -> .norm1 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[ + f"{original_transformer_prefix}.ln_1.weight" + ], + f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"], + } + ) + + # .ln_2 -> .norm3 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[ + f"{original_transformer_prefix}.ln_2.weight" + ], + f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"], + } + ) + + # .ln_post -> .norm_out + diffusers_checkpoint.update( + { + "norm_out.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_post.weight"], + "norm_out.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_post.bias"], + } + ) + + # .output_proj -> .proj_to_clip_embeddings + diffusers_checkpoint.update( + { + "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.output_proj.weight"], + "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.output_proj.bias"], + } + ) + + return diffusers_checkpoint + + +# done prior_image + + +# renderer + +## create the lookup table for marching cubes method used in MeshDecoder + +MC_TABLE = [ + [], + [[0, 1, 0, 2, 0, 4]], + [[1, 0, 1, 5, 1, 3]], + [[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2]], + [[2, 0, 2, 3, 2, 6]], + [[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4]], + [[1, 0, 1, 5, 1, 3], [2, 6, 0, 2, 3, 2]], + [[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4]], + [[3, 1, 3, 7, 3, 2]], + [[0, 2, 0, 4, 0, 1], [3, 7, 2, 3, 1, 3]], + [[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0]], + [[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5]], + [[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6]], + [[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6]], + [[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7]], + [[0, 4, 1, 5, 3, 7], [0, 4, 3, 7, 2, 6]], + [[4, 0, 4, 6, 4, 5]], + [[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1]], + [[1, 5, 1, 3, 1, 0], [4, 6, 5, 4, 0, 4]], + [[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2]], + [[2, 0, 2, 3, 2, 6], [4, 5, 0, 4, 6, 4]], + [[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1]], + [[2, 6, 2, 0, 3, 2], [1, 0, 1, 5, 3, 1], [6, 4, 5, 4, 0, 4]], + [[1, 3, 5, 4, 1, 5], [1, 3, 4, 6, 5, 4], [1, 3, 3, 2, 4, 6], [3, 2, 2, 6, 4, 6]], + [[3, 1, 3, 7, 3, 2], [6, 4, 5, 4, 0, 4]], + [[4, 5, 0, 1, 4, 6], [0, 1, 0, 2, 4, 6], [7, 3, 2, 3, 1, 3]], + [[3, 2, 1, 0, 3, 7], [1, 0, 1, 5, 3, 7], [6, 4, 5, 4, 0, 4]], + [[3, 7, 3, 2, 1, 5], [3, 2, 6, 4, 1, 5], [1, 5, 6, 4, 5, 4], [3, 2, 2, 0, 6, 4]], + [[3, 7, 2, 6, 3, 1], [2, 6, 2, 0, 3, 1], [5, 4, 0, 4, 6, 4]], + [[1, 0, 1, 3, 5, 4], [1, 3, 2, 6, 5, 4], [1, 3, 3, 7, 2, 6], [5, 4, 2, 6, 4, 6]], + [[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7], [4, 5, 0, 4, 4, 6]], + [[6, 2, 4, 6, 4, 5], [4, 5, 5, 1, 6, 2], [6, 2, 5, 1, 7, 3]], + [[5, 1, 5, 4, 5, 7]], + [[0, 1, 0, 2, 0, 4], [5, 7, 1, 5, 4, 5]], + [[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3]], + [[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3]], + [[2, 0, 2, 3, 2, 6], [7, 5, 1, 5, 4, 5]], + [[2, 6, 0, 4, 2, 3], [0, 4, 0, 1, 2, 3], [7, 5, 1, 5, 4, 5]], + [[5, 7, 1, 3, 5, 4], [1, 3, 1, 0, 5, 4], [6, 2, 0, 2, 3, 2]], + [[3, 1, 3, 2, 7, 5], [3, 2, 0, 4, 7, 5], [3, 2, 2, 6, 0, 4], [7, 5, 0, 4, 5, 4]], + [[3, 7, 3, 2, 3, 1], [5, 4, 7, 5, 1, 5]], + [[0, 4, 0, 1, 2, 0], [3, 1, 3, 7, 2, 3], [4, 5, 7, 5, 1, 5]], + [[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0]], + [[0, 4, 2, 3, 0, 2], [0, 4, 3, 7, 2, 3], [0, 4, 4, 5, 3, 7], [4, 5, 5, 7, 3, 7]], + [[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6], [4, 5, 7, 5, 1, 5]], + [[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6], [5, 7, 1, 5, 5, 4]], + [[2, 6, 2, 0, 3, 7], [2, 0, 4, 5, 3, 7], [3, 7, 4, 5, 7, 5], [2, 0, 0, 1, 4, 5]], + [[4, 0, 5, 4, 5, 7], [5, 7, 7, 3, 4, 0], [4, 0, 7, 3, 6, 2]], + [[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0]], + [[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6]], + [[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7]], + [[0, 2, 4, 6, 5, 7], [0, 2, 5, 7, 1, 3]], + [[5, 1, 4, 0, 5, 7], [4, 0, 4, 6, 5, 7], [3, 2, 6, 2, 0, 2]], + [[2, 3, 2, 6, 0, 1], [2, 6, 7, 5, 0, 1], [0, 1, 7, 5, 1, 5], [2, 6, 6, 4, 7, 5]], + [[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7], [2, 6, 0, 2, 2, 3]], + [[3, 1, 2, 3, 2, 6], [2, 6, 6, 4, 3, 1], [3, 1, 6, 4, 7, 5]], + [[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0], [2, 3, 1, 3, 7, 3]], + [[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6], [3, 2, 1, 3, 3, 7]], + [[0, 1, 0, 4, 2, 3], [0, 4, 5, 7, 2, 3], [0, 4, 4, 6, 5, 7], [2, 3, 5, 7, 3, 7]], + [[7, 5, 3, 7, 3, 2], [3, 2, 2, 0, 7, 5], [7, 5, 2, 0, 6, 4]], + [[0, 4, 4, 6, 5, 7], [0, 4, 5, 7, 1, 5], [0, 2, 1, 3, 3, 7], [3, 7, 2, 6, 0, 2]], + [ + [3, 1, 7, 3, 6, 2], + [6, 2, 0, 1, 3, 1], + [6, 4, 0, 1, 6, 2], + [6, 4, 5, 1, 0, 1], + [6, 4, 7, 5, 5, 1], + ], + [ + [4, 0, 6, 4, 7, 5], + [7, 5, 1, 0, 4, 0], + [7, 3, 1, 0, 7, 5], + [7, 3, 2, 0, 1, 0], + [7, 3, 6, 2, 2, 0], + ], + [[7, 3, 6, 2, 6, 4], [7, 5, 7, 3, 6, 4]], + [[6, 2, 6, 7, 6, 4]], + [[0, 4, 0, 1, 0, 2], [6, 7, 4, 6, 2, 6]], + [[1, 0, 1, 5, 1, 3], [7, 6, 4, 6, 2, 6]], + [[1, 3, 0, 2, 1, 5], [0, 2, 0, 4, 1, 5], [7, 6, 4, 6, 2, 6]], + [[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0]], + [[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3]], + [[6, 4, 2, 0, 6, 7], [2, 0, 2, 3, 6, 7], [5, 1, 3, 1, 0, 1]], + [[1, 5, 1, 3, 0, 4], [1, 3, 7, 6, 0, 4], [0, 4, 7, 6, 4, 6], [1, 3, 3, 2, 7, 6]], + [[3, 2, 3, 1, 3, 7], [6, 4, 2, 6, 7, 6]], + [[3, 7, 3, 2, 1, 3], [0, 2, 0, 4, 1, 0], [7, 6, 4, 6, 2, 6]], + [[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0], [4, 6, 2, 6, 7, 6]], + [[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5], [6, 4, 2, 6, 6, 7]], + [[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0]], + [[0, 1, 4, 6, 0, 4], [0, 1, 6, 7, 4, 6], [0, 1, 1, 3, 6, 7], [1, 3, 3, 7, 6, 7]], + [[0, 2, 0, 1, 4, 6], [0, 1, 3, 7, 4, 6], [0, 1, 1, 5, 3, 7], [4, 6, 3, 7, 6, 7]], + [[7, 3, 6, 7, 6, 4], [6, 4, 4, 0, 7, 3], [7, 3, 4, 0, 5, 1]], + [[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5]], + [[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5]], + [[6, 7, 4, 5, 6, 2], [4, 5, 4, 0, 6, 2], [3, 1, 0, 1, 5, 1]], + [[2, 0, 2, 6, 3, 1], [2, 6, 4, 5, 3, 1], [2, 6, 6, 7, 4, 5], [3, 1, 4, 5, 1, 5]], + [[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7]], + [[0, 1, 2, 3, 6, 7], [0, 1, 6, 7, 4, 5]], + [[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7], [1, 3, 0, 1, 1, 5]], + [[5, 4, 1, 5, 1, 3], [1, 3, 3, 2, 5, 4], [5, 4, 3, 2, 7, 6]], + [[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5], [1, 3, 7, 3, 2, 3]], + [[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5], [3, 7, 2, 3, 3, 1]], + [[0, 1, 1, 5, 3, 7], [0, 1, 3, 7, 2, 3], [0, 4, 2, 6, 6, 7], [6, 7, 4, 5, 0, 4]], + [ + [6, 2, 7, 6, 5, 4], + [5, 4, 0, 2, 6, 2], + [5, 1, 0, 2, 5, 4], + [5, 1, 3, 2, 0, 2], + [5, 1, 7, 3, 3, 2], + ], + [[3, 1, 3, 7, 2, 0], [3, 7, 5, 4, 2, 0], [2, 0, 5, 4, 0, 4], [3, 7, 7, 6, 5, 4]], + [[1, 0, 3, 1, 3, 7], [3, 7, 7, 6, 1, 0], [1, 0, 7, 6, 5, 4]], + [ + [1, 0, 5, 1, 7, 3], + [7, 3, 2, 0, 1, 0], + [7, 6, 2, 0, 7, 3], + [7, 6, 4, 0, 2, 0], + [7, 6, 5, 4, 4, 0], + ], + [[7, 6, 5, 4, 5, 1], [7, 3, 7, 6, 5, 1]], + [[5, 7, 5, 1, 5, 4], [6, 2, 7, 6, 4, 6]], + [[0, 2, 0, 4, 1, 0], [5, 4, 5, 7, 1, 5], [2, 6, 7, 6, 4, 6]], + [[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3], [2, 6, 7, 6, 4, 6]], + [[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3], [6, 7, 4, 6, 6, 2]], + [[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0], [1, 5, 4, 5, 7, 5]], + [[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3], [5, 1, 4, 5, 5, 7]], + [[0, 2, 2, 3, 6, 7], [0, 2, 6, 7, 4, 6], [0, 1, 4, 5, 5, 7], [5, 7, 1, 3, 0, 1]], + [ + [5, 4, 7, 5, 3, 1], + [3, 1, 0, 4, 5, 4], + [3, 2, 0, 4, 3, 1], + [3, 2, 6, 4, 0, 4], + [3, 2, 7, 6, 6, 4], + ], + [[5, 4, 5, 7, 1, 5], [3, 7, 3, 2, 1, 3], [4, 6, 2, 6, 7, 6]], + [[1, 0, 0, 2, 0, 4], [1, 5, 5, 4, 5, 7], [3, 2, 1, 3, 3, 7], [2, 6, 7, 6, 4, 6]], + [[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0], [6, 2, 7, 6, 6, 4]], + [ + [0, 4, 2, 3, 0, 2], + [0, 4, 3, 7, 2, 3], + [0, 4, 4, 5, 3, 7], + [4, 5, 5, 7, 3, 7], + [6, 7, 4, 6, 2, 6], + ], + [[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0], [5, 4, 7, 5, 5, 1]], + [ + [0, 1, 4, 6, 0, 4], + [0, 1, 6, 7, 4, 6], + [0, 1, 1, 3, 6, 7], + [1, 3, 3, 7, 6, 7], + [5, 7, 1, 5, 4, 5], + ], + [ + [6, 7, 4, 6, 0, 2], + [0, 2, 3, 7, 6, 7], + [0, 1, 3, 7, 0, 2], + [0, 1, 5, 7, 3, 7], + [0, 1, 4, 5, 5, 7], + ], + [[4, 0, 6, 7, 4, 6], [4, 0, 7, 3, 6, 7], [4, 0, 5, 7, 7, 3], [4, 5, 5, 7, 4, 0]], + [[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0]], + [[0, 2, 1, 5, 0, 1], [0, 2, 5, 7, 1, 5], [0, 2, 2, 6, 5, 7], [2, 6, 6, 7, 5, 7]], + [[1, 3, 1, 0, 5, 7], [1, 0, 2, 6, 5, 7], [5, 7, 2, 6, 7, 6], [1, 0, 0, 4, 2, 6]], + [[2, 0, 6, 2, 6, 7], [6, 7, 7, 5, 2, 0], [2, 0, 7, 5, 3, 1]], + [[0, 4, 0, 2, 1, 5], [0, 2, 6, 7, 1, 5], [0, 2, 2, 3, 6, 7], [1, 5, 6, 7, 5, 7]], + [[7, 6, 5, 7, 5, 1], [5, 1, 1, 0, 7, 6], [7, 6, 1, 0, 3, 2]], + [ + [2, 0, 3, 2, 7, 6], + [7, 6, 4, 0, 2, 0], + [7, 5, 4, 0, 7, 6], + [7, 5, 1, 0, 4, 0], + [7, 5, 3, 1, 1, 0], + ], + [[7, 5, 3, 1, 3, 2], [7, 6, 7, 5, 3, 2]], + [[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0], [3, 1, 7, 3, 3, 2]], + [ + [0, 2, 1, 5, 0, 1], + [0, 2, 5, 7, 1, 5], + [0, 2, 2, 6, 5, 7], + [2, 6, 6, 7, 5, 7], + [3, 7, 2, 3, 1, 3], + ], + [ + [3, 7, 2, 3, 0, 1], + [0, 1, 5, 7, 3, 7], + [0, 4, 5, 7, 0, 1], + [0, 4, 6, 7, 5, 7], + [0, 4, 2, 6, 6, 7], + ], + [[2, 0, 3, 7, 2, 3], [2, 0, 7, 5, 3, 7], [2, 0, 6, 7, 7, 5], [2, 6, 6, 7, 2, 0]], + [ + [5, 7, 1, 5, 0, 4], + [0, 4, 6, 7, 5, 7], + [0, 2, 6, 7, 0, 4], + [0, 2, 3, 7, 6, 7], + [0, 2, 1, 3, 3, 7], + ], + [[1, 0, 5, 7, 1, 5], [1, 0, 7, 6, 5, 7], [1, 0, 3, 7, 7, 6], [1, 3, 3, 7, 1, 0]], + [[0, 2, 0, 1, 0, 4], [3, 7, 6, 7, 5, 7]], + [[7, 5, 7, 3, 7, 6]], + [[7, 3, 7, 5, 7, 6]], + [[0, 1, 0, 2, 0, 4], [6, 7, 3, 7, 5, 7]], + [[1, 3, 1, 0, 1, 5], [7, 6, 3, 7, 5, 7]], + [[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2], [6, 7, 3, 7, 5, 7]], + [[2, 6, 2, 0, 2, 3], [7, 5, 6, 7, 3, 7]], + [[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4], [5, 7, 6, 7, 3, 7]], + [[1, 5, 1, 3, 0, 1], [2, 3, 2, 6, 0, 2], [5, 7, 6, 7, 3, 7]], + [[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4], [7, 6, 3, 7, 7, 5]], + [[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2]], + [[7, 6, 3, 2, 7, 5], [3, 2, 3, 1, 7, 5], [4, 0, 1, 0, 2, 0]], + [[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2]], + [[2, 3, 2, 0, 6, 7], [2, 0, 1, 5, 6, 7], [2, 0, 0, 4, 1, 5], [6, 7, 1, 5, 7, 5]], + [[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1]], + [[0, 4, 0, 1, 2, 6], [0, 1, 5, 7, 2, 6], [2, 6, 5, 7, 6, 7], [0, 1, 1, 3, 5, 7]], + [[1, 5, 0, 2, 1, 0], [1, 5, 2, 6, 0, 2], [1, 5, 5, 7, 2, 6], [5, 7, 7, 6, 2, 6]], + [[5, 1, 7, 5, 7, 6], [7, 6, 6, 2, 5, 1], [5, 1, 6, 2, 4, 0]], + [[4, 5, 4, 0, 4, 6], [7, 3, 5, 7, 6, 7]], + [[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1], [3, 7, 5, 7, 6, 7]], + [[4, 6, 4, 5, 0, 4], [1, 5, 1, 3, 0, 1], [6, 7, 3, 7, 5, 7]], + [[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2], [7, 3, 5, 7, 7, 6]], + [[2, 3, 2, 6, 0, 2], [4, 6, 4, 5, 0, 4], [3, 7, 5, 7, 6, 7]], + [[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1], [7, 5, 6, 7, 7, 3]], + [[0, 1, 1, 5, 1, 3], [0, 2, 2, 3, 2, 6], [4, 5, 0, 4, 4, 6], [5, 7, 6, 7, 3, 7]], + [ + [1, 3, 5, 4, 1, 5], + [1, 3, 4, 6, 5, 4], + [1, 3, 3, 2, 4, 6], + [3, 2, 2, 6, 4, 6], + [7, 6, 3, 7, 5, 7], + ], + [[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2], [0, 4, 6, 4, 5, 4]], + [[1, 0, 0, 2, 4, 6], [1, 0, 4, 6, 5, 4], [1, 3, 5, 7, 7, 6], [7, 6, 3, 2, 1, 3]], + [[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2], [4, 6, 5, 4, 4, 0]], + [ + [7, 5, 6, 7, 2, 3], + [2, 3, 1, 5, 7, 5], + [2, 0, 1, 5, 2, 3], + [2, 0, 4, 5, 1, 5], + [2, 0, 6, 4, 4, 5], + ], + [[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1], [4, 0, 6, 4, 4, 5]], + [ + [4, 6, 5, 4, 1, 0], + [1, 0, 2, 6, 4, 6], + [1, 3, 2, 6, 1, 0], + [1, 3, 7, 6, 2, 6], + [1, 3, 5, 7, 7, 6], + ], + [ + [1, 5, 0, 2, 1, 0], + [1, 5, 2, 6, 0, 2], + [1, 5, 5, 7, 2, 6], + [5, 7, 7, 6, 2, 6], + [4, 6, 5, 4, 0, 4], + ], + [[5, 1, 4, 6, 5, 4], [5, 1, 6, 2, 4, 6], [5, 1, 7, 6, 6, 2], [5, 7, 7, 6, 5, 1]], + [[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1]], + [[7, 3, 5, 1, 7, 6], [5, 1, 5, 4, 7, 6], [2, 0, 4, 0, 1, 0]], + [[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4]], + [[0, 2, 0, 4, 1, 3], [0, 4, 6, 7, 1, 3], [1, 3, 6, 7, 3, 7], [0, 4, 4, 5, 6, 7]], + [[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1], [0, 2, 3, 2, 6, 2]], + [[1, 5, 5, 4, 7, 6], [1, 5, 7, 6, 3, 7], [1, 0, 3, 2, 2, 6], [2, 6, 0, 4, 1, 0]], + [[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4], [2, 0, 3, 2, 2, 6]], + [ + [2, 3, 6, 2, 4, 0], + [4, 0, 1, 3, 2, 3], + [4, 5, 1, 3, 4, 0], + [4, 5, 7, 3, 1, 3], + [4, 5, 6, 7, 7, 3], + ], + [[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6]], + [[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6], [0, 4, 1, 0, 0, 2]], + [[1, 0, 5, 4, 7, 6], [1, 0, 7, 6, 3, 2]], + [[2, 3, 0, 2, 0, 4], [0, 4, 4, 5, 2, 3], [2, 3, 4, 5, 6, 7]], + [[1, 3, 1, 5, 0, 2], [1, 5, 7, 6, 0, 2], [1, 5, 5, 4, 7, 6], [0, 2, 7, 6, 2, 6]], + [ + [5, 1, 4, 5, 6, 7], + [6, 7, 3, 1, 5, 1], + [6, 2, 3, 1, 6, 7], + [6, 2, 0, 1, 3, 1], + [6, 2, 4, 0, 0, 1], + ], + [[6, 7, 2, 6, 2, 0], [2, 0, 0, 1, 6, 7], [6, 7, 0, 1, 4, 5]], + [[6, 2, 4, 0, 4, 5], [6, 7, 6, 2, 4, 5]], + [[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1]], + [[1, 5, 1, 0, 3, 7], [1, 0, 4, 6, 3, 7], [1, 0, 0, 2, 4, 6], [3, 7, 4, 6, 7, 6]], + [[1, 0, 3, 7, 1, 3], [1, 0, 7, 6, 3, 7], [1, 0, 0, 4, 7, 6], [0, 4, 4, 6, 7, 6]], + [[6, 4, 7, 6, 7, 3], [7, 3, 3, 1, 6, 4], [6, 4, 3, 1, 2, 0]], + [[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1], [2, 3, 6, 2, 2, 0]], + [ + [7, 6, 3, 7, 1, 5], + [1, 5, 4, 6, 7, 6], + [1, 0, 4, 6, 1, 5], + [1, 0, 2, 6, 4, 6], + [1, 0, 3, 2, 2, 6], + ], + [ + [1, 0, 3, 7, 1, 3], + [1, 0, 7, 6, 3, 7], + [1, 0, 0, 4, 7, 6], + [0, 4, 4, 6, 7, 6], + [2, 6, 0, 2, 3, 2], + ], + [[3, 1, 7, 6, 3, 7], [3, 1, 6, 4, 7, 6], [3, 1, 2, 6, 6, 4], [3, 2, 2, 6, 3, 1]], + [[3, 2, 3, 1, 7, 6], [3, 1, 0, 4, 7, 6], [7, 6, 0, 4, 6, 4], [3, 1, 1, 5, 0, 4]], + [ + [0, 1, 2, 0, 6, 4], + [6, 4, 5, 1, 0, 1], + [6, 7, 5, 1, 6, 4], + [6, 7, 3, 1, 5, 1], + [6, 7, 2, 3, 3, 1], + ], + [[0, 1, 4, 0, 4, 6], [4, 6, 6, 7, 0, 1], [0, 1, 6, 7, 2, 3]], + [[6, 7, 2, 3, 2, 0], [6, 4, 6, 7, 2, 0]], + [ + [2, 6, 0, 2, 1, 3], + [1, 3, 7, 6, 2, 6], + [1, 5, 7, 6, 1, 3], + [1, 5, 4, 6, 7, 6], + [1, 5, 0, 4, 4, 6], + ], + [[1, 5, 1, 0, 1, 3], [4, 6, 7, 6, 2, 6]], + [[0, 1, 2, 6, 0, 2], [0, 1, 6, 7, 2, 6], [0, 1, 4, 6, 6, 7], [0, 4, 4, 6, 0, 1]], + [[6, 7, 6, 2, 6, 4]], + [[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4]], + [[7, 5, 6, 4, 7, 3], [6, 4, 6, 2, 7, 3], [1, 0, 2, 0, 4, 0]], + [[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4], [0, 1, 5, 1, 3, 1]], + [[2, 0, 0, 4, 1, 5], [2, 0, 1, 5, 3, 1], [2, 6, 3, 7, 7, 5], [7, 5, 6, 4, 2, 6]], + [[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4]], + [[3, 2, 3, 7, 1, 0], [3, 7, 6, 4, 1, 0], [3, 7, 7, 5, 6, 4], [1, 0, 6, 4, 0, 4]], + [[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4], [1, 5, 3, 1, 1, 0]], + [ + [7, 3, 5, 7, 4, 6], + [4, 6, 2, 3, 7, 3], + [4, 0, 2, 3, 4, 6], + [4, 0, 1, 3, 2, 3], + [4, 0, 5, 1, 1, 3], + ], + [[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5]], + [[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5], [0, 1, 2, 0, 0, 4]], + [[1, 0, 1, 5, 3, 2], [1, 5, 4, 6, 3, 2], [3, 2, 4, 6, 2, 6], [1, 5, 5, 7, 4, 6]], + [ + [0, 2, 4, 0, 5, 1], + [5, 1, 3, 2, 0, 2], + [5, 7, 3, 2, 5, 1], + [5, 7, 6, 2, 3, 2], + [5, 7, 4, 6, 6, 2], + ], + [[2, 0, 3, 1, 7, 5], [2, 0, 7, 5, 6, 4]], + [[4, 6, 0, 4, 0, 1], [0, 1, 1, 3, 4, 6], [4, 6, 1, 3, 5, 7]], + [[0, 2, 1, 0, 1, 5], [1, 5, 5, 7, 0, 2], [0, 2, 5, 7, 4, 6]], + [[5, 7, 4, 6, 4, 0], [5, 1, 5, 7, 4, 0]], + [[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2]], + [[0, 1, 0, 2, 4, 5], [0, 2, 3, 7, 4, 5], [4, 5, 3, 7, 5, 7], [0, 2, 2, 6, 3, 7]], + [[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2], [1, 0, 5, 1, 1, 3]], + [ + [1, 5, 3, 1, 2, 0], + [2, 0, 4, 5, 1, 5], + [2, 6, 4, 5, 2, 0], + [2, 6, 7, 5, 4, 5], + [2, 6, 3, 7, 7, 5], + ], + [[2, 3, 0, 4, 2, 0], [2, 3, 4, 5, 0, 4], [2, 3, 3, 7, 4, 5], [3, 7, 7, 5, 4, 5]], + [[3, 2, 7, 3, 7, 5], [7, 5, 5, 4, 3, 2], [3, 2, 5, 4, 1, 0]], + [ + [2, 3, 0, 4, 2, 0], + [2, 3, 4, 5, 0, 4], + [2, 3, 3, 7, 4, 5], + [3, 7, 7, 5, 4, 5], + [1, 5, 3, 1, 0, 1], + ], + [[3, 2, 1, 5, 3, 1], [3, 2, 5, 4, 1, 5], [3, 2, 7, 5, 5, 4], [3, 7, 7, 5, 3, 2]], + [[2, 6, 2, 3, 0, 4], [2, 3, 7, 5, 0, 4], [2, 3, 3, 1, 7, 5], [0, 4, 7, 5, 4, 5]], + [ + [3, 2, 1, 3, 5, 7], + [5, 7, 6, 2, 3, 2], + [5, 4, 6, 2, 5, 7], + [5, 4, 0, 2, 6, 2], + [5, 4, 1, 0, 0, 2], + ], + [ + [4, 5, 0, 4, 2, 6], + [2, 6, 7, 5, 4, 5], + [2, 3, 7, 5, 2, 6], + [2, 3, 1, 5, 7, 5], + [2, 3, 0, 1, 1, 5], + ], + [[2, 3, 2, 0, 2, 6], [1, 5, 7, 5, 4, 5]], + [[5, 7, 4, 5, 4, 0], [4, 0, 0, 2, 5, 7], [5, 7, 0, 2, 1, 3]], + [[5, 4, 1, 0, 1, 3], [5, 7, 5, 4, 1, 3]], + [[0, 2, 4, 5, 0, 4], [0, 2, 5, 7, 4, 5], [0, 2, 1, 5, 5, 7], [0, 1, 1, 5, 0, 2]], + [[5, 4, 5, 1, 5, 7]], + [[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3]], + [[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3], [0, 2, 4, 0, 0, 1]], + [[3, 7, 3, 1, 2, 6], [3, 1, 5, 4, 2, 6], [3, 1, 1, 0, 5, 4], [2, 6, 5, 4, 6, 4]], + [ + [6, 4, 2, 6, 3, 7], + [3, 7, 5, 4, 6, 4], + [3, 1, 5, 4, 3, 7], + [3, 1, 0, 4, 5, 4], + [3, 1, 2, 0, 0, 4], + ], + [[2, 0, 2, 3, 6, 4], [2, 3, 1, 5, 6, 4], [6, 4, 1, 5, 4, 5], [2, 3, 3, 7, 1, 5]], + [ + [0, 4, 1, 0, 3, 2], + [3, 2, 6, 4, 0, 4], + [3, 7, 6, 4, 3, 2], + [3, 7, 5, 4, 6, 4], + [3, 7, 1, 5, 5, 4], + ], + [ + [1, 3, 0, 1, 4, 5], + [4, 5, 7, 3, 1, 3], + [4, 6, 7, 3, 4, 5], + [4, 6, 2, 3, 7, 3], + [4, 6, 0, 2, 2, 3], + ], + [[3, 7, 3, 1, 3, 2], [5, 4, 6, 4, 0, 4]], + [[3, 1, 2, 6, 3, 2], [3, 1, 6, 4, 2, 6], [3, 1, 1, 5, 6, 4], [1, 5, 5, 4, 6, 4]], + [ + [3, 1, 2, 6, 3, 2], + [3, 1, 6, 4, 2, 6], + [3, 1, 1, 5, 6, 4], + [1, 5, 5, 4, 6, 4], + [0, 4, 1, 0, 2, 0], + ], + [[4, 5, 6, 4, 6, 2], [6, 2, 2, 3, 4, 5], [4, 5, 2, 3, 0, 1]], + [[2, 3, 6, 4, 2, 6], [2, 3, 4, 5, 6, 4], [2, 3, 0, 4, 4, 5], [2, 0, 0, 4, 2, 3]], + [[1, 3, 5, 1, 5, 4], [5, 4, 4, 6, 1, 3], [1, 3, 4, 6, 0, 2]], + [[1, 3, 0, 4, 1, 0], [1, 3, 4, 6, 0, 4], [1, 3, 5, 4, 4, 6], [1, 5, 5, 4, 1, 3]], + [[4, 6, 0, 2, 0, 1], [4, 5, 4, 6, 0, 1]], + [[4, 6, 4, 0, 4, 5]], + [[4, 0, 6, 2, 7, 3], [4, 0, 7, 3, 5, 1]], + [[1, 5, 0, 1, 0, 2], [0, 2, 2, 6, 1, 5], [1, 5, 2, 6, 3, 7]], + [[3, 7, 1, 3, 1, 0], [1, 0, 0, 4, 3, 7], [3, 7, 0, 4, 2, 6]], + [[3, 1, 2, 0, 2, 6], [3, 7, 3, 1, 2, 6]], + [[0, 4, 2, 0, 2, 3], [2, 3, 3, 7, 0, 4], [0, 4, 3, 7, 1, 5]], + [[3, 7, 1, 5, 1, 0], [3, 2, 3, 7, 1, 0]], + [[0, 4, 1, 3, 0, 1], [0, 4, 3, 7, 1, 3], [0, 4, 2, 3, 3, 7], [0, 2, 2, 3, 0, 4]], + [[3, 7, 3, 1, 3, 2]], + [[2, 6, 3, 2, 3, 1], [3, 1, 1, 5, 2, 6], [2, 6, 1, 5, 0, 4]], + [[1, 5, 3, 2, 1, 3], [1, 5, 2, 6, 3, 2], [1, 5, 0, 2, 2, 6], [1, 0, 0, 2, 1, 5]], + [[2, 3, 0, 1, 0, 4], [2, 6, 2, 3, 0, 4]], + [[2, 3, 2, 0, 2, 6]], + [[1, 5, 0, 4, 0, 2], [1, 3, 1, 5, 0, 2]], + [[1, 5, 1, 0, 1, 3]], + [[0, 2, 0, 1, 0, 4]], + [], +] + + +def create_mc_lookup_table(): + cases = torch.zeros(256, 5, 3, dtype=torch.long) + masks = torch.zeros(256, 5, dtype=torch.bool) + + edge_to_index = { + (0, 1): 0, + (2, 3): 1, + (4, 5): 2, + (6, 7): 3, + (0, 2): 4, + (1, 3): 5, + (4, 6): 6, + (5, 7): 7, + (0, 4): 8, + (1, 5): 9, + (2, 6): 10, + (3, 7): 11, + } + + for i, case in enumerate(MC_TABLE): + for j, tri in enumerate(case): + for k, (c1, c2) in enumerate(zip(tri[::2], tri[1::2])): + cases[i, j, k] = edge_to_index[(c1, c2) if c1 < c2 else (c2, c1)] + masks[i, j] = True + return cases, masks + + +RENDERER_CONFIG = {} + + +def renderer_model_from_original_config(): + model = ShapERenderer(**RENDERER_CONFIG) + + return model + + +RENDERER_MLP_ORIGINAL_PREFIX = "renderer.nerstf" + +RENDERER_PARAMS_PROJ_ORIGINAL_PREFIX = "encoder.params_proj" + + +def renderer_model_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + diffusers_checkpoint.update( + {f"mlp.{k}": checkpoint[f"{RENDERER_MLP_ORIGINAL_PREFIX}.{k}"] for k in model.mlp.state_dict().keys()} + ) + + diffusers_checkpoint.update( + { + f"params_proj.{k}": checkpoint[f"{RENDERER_PARAMS_PROJ_ORIGINAL_PREFIX}.{k}"] + for k in model.params_proj.state_dict().keys() + } + ) + + diffusers_checkpoint.update({"void.background": model.state_dict()["void.background"]}) + + cases, masks = create_mc_lookup_table() + + diffusers_checkpoint.update({"mesh_decoder.cases": cases}) + diffusers_checkpoint.update({"mesh_decoder.masks": masks}) + + return diffusers_checkpoint + + +# done renderer + + +# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?) +def split_attentions(*, weight, bias, split, chunk_size): + weights = [None] * split + biases = [None] * split + + weights_biases_idx = 0 + + for starting_row_index in range(0, weight.shape[0], chunk_size): + row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size) + + weight_rows = weight[row_indices, :] + bias_rows = bias[row_indices] + + if weights[weights_biases_idx] is None: + assert weights[weights_biases_idx] is None + weights[weights_biases_idx] = weight_rows + biases[weights_biases_idx] = bias_rows + else: + assert weights[weights_biases_idx] is not None + weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows]) + biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows]) + + weights_biases_idx = (weights_biases_idx + 1) % split + + return weights, biases + + +# done unet utils + + +# Driver functions + + +def prior(*, args, checkpoint_map_location): + print("loading prior") + + prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location) + + prior_model = prior_model_from_original_config() + + prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint) + + del prior_checkpoint + + load_prior_checkpoint_to_model(prior_diffusers_checkpoint, prior_model) + + print("done loading prior") + + return prior_model + + +def prior_image(*, args, checkpoint_map_location): + print("loading prior_image") + + print(f"load checkpoint from {args.prior_image_checkpoint_path}") + prior_checkpoint = torch.load(args.prior_image_checkpoint_path, map_location=checkpoint_map_location) + + prior_model = prior_image_model_from_original_config() + + prior_diffusers_checkpoint = prior_image_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint) + + del prior_checkpoint + + load_prior_checkpoint_to_model(prior_diffusers_checkpoint, prior_model) + + print("done loading prior_image") + + return prior_model + + +def renderer(*, args, checkpoint_map_location): + print(" loading renderer") + + renderer_checkpoint = torch.load(args.transmitter_checkpoint_path, map_location=checkpoint_map_location) + + renderer_model = renderer_model_from_original_config() + + renderer_diffusers_checkpoint = renderer_model_original_checkpoint_to_diffusers_checkpoint( + renderer_model, renderer_checkpoint + ) + + del renderer_checkpoint + + load_checkpoint_to_model(renderer_diffusers_checkpoint, renderer_model, strict=True) + + print("done loading renderer") + + return renderer_model + + +# prior model will expect clip_mean and clip_std, whic are missing from the state_dict +PRIOR_EXPECTED_MISSING_KEYS = ["clip_mean", "clip_std"] + + +def load_prior_checkpoint_to_model(checkpoint, model): + with tempfile.NamedTemporaryFile() as file: + torch.save(checkpoint, file.name) + del checkpoint + missing_keys, unexpected_keys = model.load_state_dict(torch.load(file.name), strict=False) + missing_keys = list(set(missing_keys) - set(PRIOR_EXPECTED_MISSING_KEYS)) + + if len(unexpected_keys) > 0: + raise ValueError(f"Unexpected keys when loading prior model: {unexpected_keys}") + if len(missing_keys) > 0: + raise ValueError(f"Missing keys when loading prior model: {missing_keys}") + + +def load_checkpoint_to_model(checkpoint, model, strict=False): + with tempfile.NamedTemporaryFile() as file: + torch.save(checkpoint, file.name) + del checkpoint + if strict: + model.load_state_dict(torch.load(file.name), strict=True) + else: + load_checkpoint_and_dispatch(model, file.name, device_map="auto") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--prior_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the prior checkpoint to convert.", + ) + + parser.add_argument( + "--prior_image_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the prior_image checkpoint to convert.", + ) + + parser.add_argument( + "--transmitter_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the transmitter checkpoint to convert.", + ) + + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + parser.add_argument( + "--debug", + default=None, + type=str, + required=False, + help="Only run a specific stage of the convert script. Used for debugging", + ) + + args = parser.parse_args() + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + if args.debug is not None: + print(f"debug: only executing {args.debug}") + + if args.debug is None: + print("YiYi TO-DO") + elif args.debug == "prior": + prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location) + prior_model.save_pretrained(args.dump_path) + elif args.debug == "prior_image": + prior_model = prior_image(args=args, checkpoint_map_location=checkpoint_map_location) + prior_model.save_pretrained(args.dump_path) + elif args.debug == "renderer": + renderer_model = renderer(args=args, checkpoint_map_location=checkpoint_map_location) + renderer_model.save_pretrained(args.dump_path) + else: + raise ValueError(f"unknown debug value : {args.debug}") diff --git a/diffusers/scripts/convert_stable_audio.py b/diffusers/scripts/convert_stable_audio.py new file mode 100755 index 0000000..a0f9d0f --- /dev/null +++ b/diffusers/scripts/convert_stable_audio.py @@ -0,0 +1,279 @@ +# Run this script to convert the Stable Cascade model weights to a diffusers pipeline. +import argparse +import json +import os +from contextlib import nullcontext + +import torch +from safetensors.torch import load_file +from transformers import ( + AutoTokenizer, + T5EncoderModel, +) + +from diffusers import ( + AutoencoderOobleck, + CosineDPMSolverMultistepScheduler, + StableAudioDiTModel, + StableAudioPipeline, + StableAudioProjectionModel, +) +from diffusers.models.modeling_utils import load_model_dict_into_meta +from diffusers.utils import is_accelerate_available + + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_layers=5): + projection_model_state_dict = { + k.replace("conditioner.conditioners.", "").replace("embedder.embedding", "time_positional_embedding"): v + for (k, v) in state_dict.items() + if "conditioner.conditioners" in k + } + + # NOTE: we assume here that there's no projection layer from the text encoder to the latent space, script should be adapted a bit if there is. + for key, value in list(projection_model_state_dict.items()): + new_key = key.replace("seconds_start", "start_number_conditioner").replace( + "seconds_total", "end_number_conditioner" + ) + projection_model_state_dict[new_key] = projection_model_state_dict.pop(key) + + model_state_dict = {k.replace("model.model.", ""): v for (k, v) in state_dict.items() if "model.model." in k} + for key, value in list(model_state_dict.items()): + # attention layers + new_key = ( + key.replace("transformer.", "") + .replace("layers", "transformer_blocks") + .replace("self_attn", "attn1") + .replace("cross_attn", "attn2") + .replace("ff.ff", "ff.net") + ) + new_key = ( + new_key.replace("pre_norm", "norm1") + .replace("cross_attend_norm", "norm2") + .replace("ff_norm", "norm3") + .replace("to_out", "to_out.0") + ) + new_key = new_key.replace("gamma", "weight").replace("beta", "bias") # replace layernorm + + # other layers + new_key = ( + new_key.replace("project", "proj") + .replace("to_timestep_embed", "timestep_proj") + .replace("timestep_features", "time_proj") + .replace("to_global_embed", "global_proj") + .replace("to_cond_embed", "cross_attention_proj") + ) + + # we're using diffusers implementation of time_proj (GaussianFourierProjection) which creates a 1D tensor + if new_key == "time_proj.weight": + model_state_dict[key] = model_state_dict[key].squeeze(1) + + if "to_qkv" in new_key: + q, k, v = torch.chunk(model_state_dict.pop(key), 3, dim=0) + model_state_dict[new_key.replace("qkv", "q")] = q + model_state_dict[new_key.replace("qkv", "k")] = k + model_state_dict[new_key.replace("qkv", "v")] = v + elif "to_kv" in new_key: + k, v = torch.chunk(model_state_dict.pop(key), 2, dim=0) + model_state_dict[new_key.replace("kv", "k")] = k + model_state_dict[new_key.replace("kv", "v")] = v + else: + model_state_dict[new_key] = model_state_dict.pop(key) + + autoencoder_state_dict = { + k.replace("pretransform.model.", "").replace("coder.layers.0", "coder.conv1"): v + for (k, v) in state_dict.items() + if "pretransform.model." in k + } + + for key, _ in list(autoencoder_state_dict.items()): + new_key = key + if "coder.layers" in new_key: + # get idx of the layer + idx = int(new_key.split("coder.layers.")[1].split(".")[0]) + + new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}") + + if "encoder" in new_key: + for i in range(3): + new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}") + new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1") + new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1") + else: + for i in range(2, 5): + new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}") + new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1") + new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1") + + new_key = new_key.replace("layers.0.beta", "snake1.beta") + new_key = new_key.replace("layers.0.alpha", "snake1.alpha") + new_key = new_key.replace("layers.2.beta", "snake2.beta") + new_key = new_key.replace("layers.2.alpha", "snake2.alpha") + new_key = new_key.replace("layers.1.bias", "conv1.bias") + new_key = new_key.replace("layers.1.weight_", "conv1.weight_") + new_key = new_key.replace("layers.3.bias", "conv2.bias") + new_key = new_key.replace("layers.3.weight_", "conv2.weight_") + + if idx == num_autoencoder_layers + 1: + new_key = new_key.replace(f"block.{idx-1}", "snake1") + elif idx == num_autoencoder_layers + 2: + new_key = new_key.replace(f"block.{idx-1}", "conv2") + + else: + new_key = new_key + + value = autoencoder_state_dict.pop(key) + if "snake" in new_key: + value = value.unsqueeze(0).unsqueeze(-1) + if new_key in autoencoder_state_dict: + raise ValueError(f"{new_key} already in state dict.") + autoencoder_state_dict[new_key] = value + + return model_state_dict, projection_model_state_dict, autoencoder_state_dict + + +parser = argparse.ArgumentParser(description="Convert Stable Audio 1.0 model weights to a diffusers pipeline") +parser.add_argument("--model_folder_path", type=str, help="Location of Stable Audio weights and config") +parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion") +parser.add_argument( + "--save_directory", + type=str, + default="./tmp/stable-audio-1.0", + help="Directory to save a pipeline to. Will be created if it doesn't exist.", +) +parser.add_argument( + "--repo_id", + type=str, + default="stable-audio-1.0", + help="Hub organization to save the pipelines to", +) +parser.add_argument("--push_to_hub", action="store_true", help="Push to hub") +parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights") + +args = parser.parse_args() + +checkpoint_path = ( + os.path.join(args.model_folder_path, "model.safetensors") + if args.use_safetensors + else os.path.join(args.model_folder_path, "model.ckpt") +) +config_path = os.path.join(args.model_folder_path, "model_config.json") + +device = "cpu" +if args.variant == "bf16": + dtype = torch.bfloat16 +else: + dtype = torch.float32 + +with open(config_path) as f_in: + config_dict = json.load(f_in) + +conditioning_dict = { + conditioning["id"]: conditioning["config"] for conditioning in config_dict["model"]["conditioning"]["configs"] +} + +t5_model_config = conditioning_dict["prompt"] + +# T5 Text encoder +text_encoder = T5EncoderModel.from_pretrained(t5_model_config["t5_model_name"]) +tokenizer = AutoTokenizer.from_pretrained( + t5_model_config["t5_model_name"], truncation=True, model_max_length=t5_model_config["max_length"] +) + + +# scheduler +scheduler = CosineDPMSolverMultistepScheduler( + sigma_min=0.3, + sigma_max=500, + solver_order=2, + prediction_type="v_prediction", + sigma_data=1.0, + sigma_schedule="exponential", +) +ctx = init_empty_weights if is_accelerate_available() else nullcontext + + +if args.use_safetensors: + orig_state_dict = load_file(checkpoint_path, device=device) +else: + orig_state_dict = torch.load(checkpoint_path, map_location=device) + + +model_config = config_dict["model"]["diffusion"]["config"] + +model_state_dict, projection_model_state_dict, autoencoder_state_dict = convert_stable_audio_state_dict_to_diffusers( + orig_state_dict +) + + +with ctx(): + projection_model = StableAudioProjectionModel( + text_encoder_dim=text_encoder.config.d_model, + conditioning_dim=config_dict["model"]["conditioning"]["cond_dim"], + min_value=conditioning_dict["seconds_start"][ + "min_val" + ], # assume `seconds_start` and `seconds_total` have the same min / max values. + max_value=conditioning_dict["seconds_start"][ + "max_val" + ], # assume `seconds_start` and `seconds_total` have the same min / max values. + ) +if is_accelerate_available(): + load_model_dict_into_meta(projection_model, projection_model_state_dict) +else: + projection_model.load_state_dict(projection_model_state_dict) + +attention_head_dim = model_config["embed_dim"] // model_config["num_heads"] +with ctx(): + model = StableAudioDiTModel( + sample_size=int(config_dict["sample_size"]) + / int(config_dict["model"]["pretransform"]["config"]["downsampling_ratio"]), + in_channels=model_config["io_channels"], + num_layers=model_config["depth"], + attention_head_dim=attention_head_dim, + num_key_value_attention_heads=model_config["cond_token_dim"] // attention_head_dim, + num_attention_heads=model_config["num_heads"], + out_channels=model_config["io_channels"], + cross_attention_dim=model_config["cond_token_dim"], + time_proj_dim=256, + global_states_input_dim=model_config["global_cond_dim"], + cross_attention_input_dim=model_config["cond_token_dim"], + ) +if is_accelerate_available(): + load_model_dict_into_meta(model, model_state_dict) +else: + model.load_state_dict(model_state_dict) + + +autoencoder_config = config_dict["model"]["pretransform"]["config"] +with ctx(): + autoencoder = AutoencoderOobleck( + encoder_hidden_size=autoencoder_config["encoder"]["config"]["channels"], + downsampling_ratios=autoencoder_config["encoder"]["config"]["strides"], + decoder_channels=autoencoder_config["decoder"]["config"]["channels"], + decoder_input_channels=autoencoder_config["decoder"]["config"]["latent_dim"], + audio_channels=autoencoder_config["io_channels"], + channel_multiples=autoencoder_config["encoder"]["config"]["c_mults"], + sampling_rate=config_dict["sample_rate"], + ) + +if is_accelerate_available(): + load_model_dict_into_meta(autoencoder, autoencoder_state_dict) +else: + autoencoder.load_state_dict(autoencoder_state_dict) + + +# Prior pipeline +pipeline = StableAudioPipeline( + transformer=model, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=scheduler, + vae=autoencoder, + projection_model=projection_model, +) +pipeline.to(dtype).save_pretrained( + args.save_directory, repo_id=args.repo_id, push_to_hub=args.push_to_hub, variant=args.variant +) diff --git a/diffusers/scripts/convert_stable_cascade.py b/diffusers/scripts/convert_stable_cascade.py new file mode 100755 index 0000000..ce10970 --- /dev/null +++ b/diffusers/scripts/convert_stable_cascade.py @@ -0,0 +1,218 @@ +# Run this script to convert the Stable Cascade model weights to a diffusers pipeline. +import argparse +from contextlib import nullcontext + +import torch +from safetensors.torch import load_file +from transformers import ( + AutoTokenizer, + CLIPConfig, + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPVisionModelWithProjection, +) + +from diffusers import ( + DDPMWuerstchenScheduler, + StableCascadeCombinedPipeline, + StableCascadeDecoderPipeline, + StableCascadePriorPipeline, +) +from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers +from diffusers.models import StableCascadeUNet +from diffusers.models.modeling_utils import load_model_dict_into_meta +from diffusers.pipelines.wuerstchen import PaellaVQModel +from diffusers.utils import is_accelerate_available + + +if is_accelerate_available(): + from accelerate import init_empty_weights + +parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline") +parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights") +parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file") +parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file") +parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c") +parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b") +parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion") +parser.add_argument( + "--prior_output_path", default="stable-cascade-prior", type=str, help="Hub organization to save the pipelines to" +) +parser.add_argument( + "--decoder_output_path", + type=str, + default="stable-cascade-decoder", + help="Hub organization to save the pipelines to", +) +parser.add_argument( + "--combined_output_path", + type=str, + default="stable-cascade-combined", + help="Hub organization to save the pipelines to", +) +parser.add_argument("--save_combined", action="store_true") +parser.add_argument("--push_to_hub", action="store_true", help="Push to hub") +parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights") + +args = parser.parse_args() + +if args.skip_stage_b and args.skip_stage_c: + raise ValueError("At least one stage should be converted") +if (args.skip_stage_b or args.skip_stage_c) and args.save_combined: + raise ValueError("Cannot skip stages when creating a combined pipeline") + +model_path = args.model_path + +device = "cpu" +if args.variant == "bf16": + dtype = torch.bfloat16 +else: + dtype = torch.float32 + +# set paths to model weights +prior_checkpoint_path = f"{model_path}/{args.stage_c_name}" +decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}" + +# Clip Text encoder and tokenizer +config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") +config.text_config.projection_dim = config.projection_dim +text_encoder = CLIPTextModelWithProjection.from_pretrained( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config +) +tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + +# image processor +feature_extractor = CLIPImageProcessor() +image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + +# scheduler for prior and decoder +scheduler = DDPMWuerstchenScheduler() +ctx = init_empty_weights if is_accelerate_available() else nullcontext + +if not args.skip_stage_c: + # Prior + if args.use_safetensors: + prior_orig_state_dict = load_file(prior_checkpoint_path, device=device) + else: + prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device) + + prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict) + + with ctx(): + prior_model = StableCascadeUNet( + in_channels=16, + out_channels=16, + timestep_ratio_embedding_dim=64, + patch_size=1, + conditioning_dim=2048, + block_out_channels=[2048, 2048], + num_attention_heads=[32, 32], + down_num_layers_per_block=[8, 24], + up_num_layers_per_block=[24, 8], + down_blocks_repeat_mappers=[1, 1], + up_blocks_repeat_mappers=[1, 1], + block_types_per_layer=[ + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ], + clip_text_in_channels=1280, + clip_text_pooled_in_channels=1280, + clip_image_in_channels=768, + clip_seq=4, + kernel_size=3, + dropout=[0.1, 0.1], + self_attn=True, + timestep_conditioning_type=["sca", "crp"], + switch_level=[False], + ) + if is_accelerate_available(): + load_model_dict_into_meta(prior_model, prior_state_dict) + else: + prior_model.load_state_dict(prior_state_dict) + + # Prior pipeline + prior_pipeline = StableCascadePriorPipeline( + prior=prior_model, + tokenizer=tokenizer, + text_encoder=text_encoder, + image_encoder=image_encoder, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + prior_pipeline.to(dtype).save_pretrained( + args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant + ) + +if not args.skip_stage_b: + # Decoder + if args.use_safetensors: + decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device) + else: + decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device) + + decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict) + with ctx(): + decoder = StableCascadeUNet( + in_channels=4, + out_channels=4, + timestep_ratio_embedding_dim=64, + patch_size=2, + conditioning_dim=1280, + block_out_channels=[320, 640, 1280, 1280], + down_num_layers_per_block=[2, 6, 28, 6], + up_num_layers_per_block=[6, 28, 6, 2], + down_blocks_repeat_mappers=[1, 1, 1, 1], + up_blocks_repeat_mappers=[3, 3, 2, 2], + num_attention_heads=[0, 0, 20, 20], + block_types_per_layer=[ + ["SDCascadeResBlock", "SDCascadeTimestepBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ], + clip_text_pooled_in_channels=1280, + clip_seq=4, + effnet_in_channels=16, + pixel_mapper_in_channels=3, + kernel_size=3, + dropout=[0, 0, 0.1, 0.1], + self_attn=True, + timestep_conditioning_type=["sca"], + ) + + if is_accelerate_available(): + load_model_dict_into_meta(decoder, decoder_state_dict) + else: + decoder.load_state_dict(decoder_state_dict) + + # VQGAN from Wuerstchen-V2 + vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan") + + # Decoder pipeline + decoder_pipeline = StableCascadeDecoderPipeline( + decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler + ) + decoder_pipeline.to(dtype).save_pretrained( + args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant + ) + +if args.save_combined: + # Stable Cascade combined pipeline + stable_cascade_pipeline = StableCascadeCombinedPipeline( + # Decoder + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqmodel, + # Prior + prior_text_encoder=text_encoder, + prior_tokenizer=tokenizer, + prior_prior=prior_model, + prior_scheduler=scheduler, + prior_image_encoder=image_encoder, + prior_feature_extractor=feature_extractor, + ) + stable_cascade_pipeline.to(dtype).save_pretrained( + args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant + ) diff --git a/diffusers/scripts/convert_stable_cascade_lite.py b/diffusers/scripts/convert_stable_cascade_lite.py new file mode 100755 index 0000000..ddccaa3 --- /dev/null +++ b/diffusers/scripts/convert_stable_cascade_lite.py @@ -0,0 +1,226 @@ +# Run this script to convert the Stable Cascade model weights to a diffusers pipeline. +import argparse +from contextlib import nullcontext + +import torch +from safetensors.torch import load_file +from transformers import ( + AutoTokenizer, + CLIPConfig, + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPVisionModelWithProjection, +) + +from diffusers import ( + DDPMWuerstchenScheduler, + StableCascadeCombinedPipeline, + StableCascadeDecoderPipeline, + StableCascadePriorPipeline, +) +from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers +from diffusers.models import StableCascadeUNet +from diffusers.models.modeling_utils import load_model_dict_into_meta +from diffusers.pipelines.wuerstchen import PaellaVQModel +from diffusers.utils import is_accelerate_available + + +if is_accelerate_available(): + from accelerate import init_empty_weights + +parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline") +parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights") +parser.add_argument( + "--stage_c_name", type=str, default="stage_c_lite.safetensors", help="Name of stage c checkpoint file" +) +parser.add_argument( + "--stage_b_name", type=str, default="stage_b_lite.safetensors", help="Name of stage b checkpoint file" +) +parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c") +parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b") +parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion") +parser.add_argument( + "--prior_output_path", + default="stable-cascade-prior-lite", + type=str, + help="Hub organization to save the pipelines to", +) +parser.add_argument( + "--decoder_output_path", + type=str, + default="stable-cascade-decoder-lite", + help="Hub organization to save the pipelines to", +) +parser.add_argument( + "--combined_output_path", + type=str, + default="stable-cascade-combined-lite", + help="Hub organization to save the pipelines to", +) +parser.add_argument("--save_combined", action="store_true") +parser.add_argument("--push_to_hub", action="store_true", help="Push to hub") +parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights") + +args = parser.parse_args() + +if args.skip_stage_b and args.skip_stage_c: + raise ValueError("At least one stage should be converted") +if (args.skip_stage_b or args.skip_stage_c) and args.save_combined: + raise ValueError("Cannot skip stages when creating a combined pipeline") + +model_path = args.model_path + +device = "cpu" +if args.variant == "bf16": + dtype = torch.bfloat16 +else: + dtype = torch.float32 + +# set paths to model weights +prior_checkpoint_path = f"{model_path}/{args.stage_c_name}" +decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}" + +# Clip Text encoder and tokenizer +config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") +config.text_config.projection_dim = config.projection_dim +text_encoder = CLIPTextModelWithProjection.from_pretrained( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config +) +tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + +# image processor +feature_extractor = CLIPImageProcessor() +image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") +# scheduler for prior and decoder +scheduler = DDPMWuerstchenScheduler() + +ctx = init_empty_weights if is_accelerate_available() else nullcontext + +if not args.skip_stage_c: + # Prior + if args.use_safetensors: + prior_orig_state_dict = load_file(prior_checkpoint_path, device=device) + else: + prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device) + + prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict) + with ctx(): + prior_model = StableCascadeUNet( + in_channels=16, + out_channels=16, + timestep_ratio_embedding_dim=64, + patch_size=1, + conditioning_dim=1536, + block_out_channels=[1536, 1536], + num_attention_heads=[24, 24], + down_num_layers_per_block=[4, 12], + up_num_layers_per_block=[12, 4], + down_blocks_repeat_mappers=[1, 1], + up_blocks_repeat_mappers=[1, 1], + block_types_per_layer=[ + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ], + clip_text_in_channels=1280, + clip_text_pooled_in_channels=1280, + clip_image_in_channels=768, + clip_seq=4, + kernel_size=3, + dropout=[0.1, 0.1], + self_attn=True, + timestep_conditioning_type=["sca", "crp"], + switch_level=[False], + ) + + if is_accelerate_available(): + load_model_dict_into_meta(prior_model, prior_state_dict) + else: + prior_model.load_state_dict(prior_state_dict) + + # Prior pipeline + prior_pipeline = StableCascadePriorPipeline( + prior=prior_model, + tokenizer=tokenizer, + text_encoder=text_encoder, + image_encoder=image_encoder, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + prior_pipeline.to(dtype).save_pretrained( + args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant + ) + +if not args.skip_stage_b: + # Decoder + if args.use_safetensors: + decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device) + else: + decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device) + + decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict) + + with ctx(): + decoder = StableCascadeUNet( + in_channels=4, + out_channels=4, + timestep_ratio_embedding_dim=64, + patch_size=2, + conditioning_dim=1280, + block_out_channels=[320, 576, 1152, 1152], + down_num_layers_per_block=[2, 4, 14, 4], + up_num_layers_per_block=[4, 14, 4, 2], + down_blocks_repeat_mappers=[1, 1, 1, 1], + up_blocks_repeat_mappers=[2, 2, 2, 2], + num_attention_heads=[0, 9, 18, 18], + block_types_per_layer=[ + ["SDCascadeResBlock", "SDCascadeTimestepBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ], + clip_text_pooled_in_channels=1280, + clip_seq=4, + effnet_in_channels=16, + pixel_mapper_in_channels=3, + kernel_size=3, + dropout=[0, 0, 0.1, 0.1], + self_attn=True, + timestep_conditioning_type=["sca"], + ) + + if is_accelerate_available(): + load_model_dict_into_meta(decoder, decoder_state_dict) + else: + decoder.load_state_dict(decoder_state_dict) + + # VQGAN from Wuerstchen-V2 + vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan") + + # Decoder pipeline + decoder_pipeline = StableCascadeDecoderPipeline( + decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler + ) + decoder_pipeline.to(dtype).save_pretrained( + args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant + ) + +if args.save_combined: + # Stable Cascade combined pipeline + stable_cascade_pipeline = StableCascadeCombinedPipeline( + # Decoder + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqmodel, + # Prior + prior_text_encoder=text_encoder, + prior_tokenizer=tokenizer, + prior_prior=prior_model, + prior_scheduler=scheduler, + prior_image_encoder=image_encoder, + prior_feature_extractor=feature_extractor, + ) + stable_cascade_pipeline.to(dtype).save_pretrained( + args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant + ) diff --git a/diffusers/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/diffusers/scripts/convert_stable_diffusion_checkpoint_to_onnx.py new file mode 100755 index 0000000..96546b6 --- /dev/null +++ b/diffusers/scripts/convert_stable_diffusion_checkpoint_to_onnx.py @@ -0,0 +1,265 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import shutil +from pathlib import Path + +import onnx +import torch +from packaging import version +from torch.onnx import export + +from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline + + +is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11") + + +def onnx_export( + model, + model_args: tuple, + output_path: Path, + ordered_input_names, + output_names, + dynamic_axes, + opset, + use_external_data_format=False, +): + output_path.parent.mkdir(parents=True, exist_ok=True) + # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11, + # so we check the torch version for backwards compatibility + if is_torch_less_than_1_11: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + use_external_data_format=use_external_data_format, + enable_onnx_checker=True, + opset_version=opset, + ) + else: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + opset_version=opset, + ) + + +@torch.no_grad() +def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False): + dtype = torch.float16 if fp16 else torch.float32 + if fp16 and torch.cuda.is_available(): + device = "cuda" + elif fp16 and not torch.cuda.is_available(): + raise ValueError("`float16` model export is only supported on GPUs with CUDA") + else: + device = "cpu" + pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device) + output_path = Path(output_path) + + # TEXT ENCODER + num_tokens = pipeline.text_encoder.config.max_position_embeddings + text_hidden_size = pipeline.text_encoder.config.hidden_size + text_input = pipeline.tokenizer( + "A sample prompt", + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + onnx_export( + pipeline.text_encoder, + # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files + model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)), + output_path=output_path / "text_encoder" / "model.onnx", + ordered_input_names=["input_ids"], + output_names=["last_hidden_state", "pooler_output"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, + opset=opset, + ) + del pipeline.text_encoder + + # UNET + unet_in_channels = pipeline.unet.config.in_channels + unet_sample_size = pipeline.unet.config.sample_size + unet_path = output_path / "unet" / "model.onnx" + onnx_export( + pipeline.unet, + model_args=( + torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), + torch.randn(2).to(device=device, dtype=dtype), + torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype), + False, + ), + output_path=unet_path, + ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], + output_names=["out_sample"], # has to be different from "sample" for correct tracing + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "timestep": {0: "batch"}, + "encoder_hidden_states": {0: "batch", 1: "sequence"}, + }, + opset=opset, + use_external_data_format=True, # UNet is > 2GB, so the weights need to be split + ) + unet_model_path = str(unet_path.absolute().as_posix()) + unet_dir = os.path.dirname(unet_model_path) + unet = onnx.load(unet_model_path) + # clean up existing tensor files + shutil.rmtree(unet_dir) + os.mkdir(unet_dir) + # collate external tensor files into one + onnx.save_model( + unet, + unet_model_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.pb", + convert_attribute=False, + ) + del pipeline.unet + + # VAE ENCODER + vae_encoder = pipeline.vae + vae_in_channels = vae_encoder.config.in_channels + vae_sample_size = vae_encoder.config.sample_size + # need to get the raw tensor output (sample) from the encoder + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() + onnx_export( + vae_encoder, + model_args=( + torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype), + False, + ), + output_path=output_path / "vae_encoder" / "model.onnx", + ordered_input_names=["sample", "return_dict"], + output_names=["latent_sample"], + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + opset=opset, + ) + + # VAE DECODER + vae_decoder = pipeline.vae + vae_latent_channels = vae_decoder.config.latent_channels + vae_out_channels = vae_decoder.config.out_channels + # forward only through the decoder part + vae_decoder.forward = vae_encoder.decode + onnx_export( + vae_decoder, + model_args=( + torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), + False, + ), + output_path=output_path / "vae_decoder" / "model.onnx", + ordered_input_names=["latent_sample", "return_dict"], + output_names=["sample"], + dynamic_axes={ + "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + opset=opset, + ) + del pipeline.vae + + # SAFETY CHECKER + if pipeline.safety_checker is not None: + safety_checker = pipeline.safety_checker + clip_num_channels = safety_checker.config.vision_config.num_channels + clip_image_size = safety_checker.config.vision_config.image_size + safety_checker.forward = safety_checker.forward_onnx + onnx_export( + pipeline.safety_checker, + model_args=( + torch.randn( + 1, + clip_num_channels, + clip_image_size, + clip_image_size, + ).to(device=device, dtype=dtype), + torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype), + ), + output_path=output_path / "safety_checker" / "model.onnx", + ordered_input_names=["clip_input", "images"], + output_names=["out_images", "has_nsfw_concepts"], + dynamic_axes={ + "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "images": {0: "batch", 1: "height", 2: "width", 3: "channels"}, + }, + opset=opset, + ) + del pipeline.safety_checker + safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker") + feature_extractor = pipeline.feature_extractor + else: + safety_checker = None + feature_extractor = None + + onnx_pipeline = OnnxStableDiffusionPipeline( + vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"), + vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"), + text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"), + tokenizer=pipeline.tokenizer, + unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"), + scheduler=pipeline.scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=safety_checker is not None, + ) + + onnx_pipeline.save_pretrained(output_path) + print("ONNX pipeline saved to", output_path) + + del pipeline + del onnx_pipeline + _ = OnnxStableDiffusionPipeline.from_pretrained(output_path, provider="CPUExecutionProvider") + print("ONNX pipeline is loadable") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_path", + type=str, + required=True, + help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).", + ) + + parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--opset", + default=14, + type=int, + help="The version of the ONNX operator set to use.", + ) + parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode") + + args = parser.parse_args() + + convert_models(args.model_path, args.output_path, args.opset, args.fp16) diff --git a/diffusers/scripts/convert_stable_diffusion_controlnet_to_onnx.py b/diffusers/scripts/convert_stable_diffusion_controlnet_to_onnx.py new file mode 100755 index 0000000..4af39b2 --- /dev/null +++ b/diffusers/scripts/convert_stable_diffusion_controlnet_to_onnx.py @@ -0,0 +1,505 @@ +import argparse +import os +import shutil +from pathlib import Path + +import onnx +import onnx_graphsurgeon as gs +import torch +from onnx import shape_inference +from packaging import version +from polygraphy.backend.onnx.loader import fold_constants +from torch.onnx import export + +from diffusers import ( + ControlNetModel, + StableDiffusionControlNetImg2ImgPipeline, +) +from diffusers.models.attention_processor import AttnProcessor +from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline + + +is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11") +is_torch_2_0_1 = version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.1") + + +class Optimizer: + def __init__(self, onnx_graph, verbose=False): + self.graph = gs.import_onnx(onnx_graph) + self.verbose = verbose + + def info(self, prefix): + if self.verbose: + print( + f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs" + ) + + def cleanup(self, return_onnx=False): + self.graph.cleanup().toposort() + if return_onnx: + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self, return_onnx=False): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def infer_shapes(self, return_onnx=False): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() > 2147483648: + raise TypeError("ERROR: model size exceeds supported 2GB limit") + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + +def optimize(onnx_graph, name, verbose): + opt = Optimizer(onnx_graph, verbose=verbose) + opt.info(name + ": original") + opt.cleanup() + opt.info(name + ": cleanup") + opt.fold_constants() + opt.info(name + ": fold constants") + # opt.infer_shapes() + # opt.info(name + ': shape inference') + onnx_opt_graph = opt.cleanup(return_onnx=True) + opt.info(name + ": finished") + return onnx_opt_graph + + +class UNet2DConditionControlNetModel(torch.nn.Module): + def __init__( + self, + unet, + controlnets: ControlNetModel, + ): + super().__init__() + self.unet = unet + self.controlnets = controlnets + + def forward( + self, + sample, + timestep, + encoder_hidden_states, + controlnet_conds, + controlnet_scales, + ): + for i, (controlnet_cond, conditioning_scale, controlnet) in enumerate( + zip(controlnet_conds, controlnet_scales, self.controlnets) + ): + down_samples, mid_sample = controlnet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_cond, + conditioning_scale=conditioning_scale, + return_dict=False, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + noise_pred = self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + return noise_pred + + +class UNet2DConditionXLControlNetModel(torch.nn.Module): + def __init__( + self, + unet, + controlnets: ControlNetModel, + ): + super().__init__() + self.unet = unet + self.controlnets = controlnets + + def forward( + self, + sample, + timestep, + encoder_hidden_states, + controlnet_conds, + controlnet_scales, + text_embeds, + time_ids, + ): + added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids} + for i, (controlnet_cond, conditioning_scale, controlnet) in enumerate( + zip(controlnet_conds, controlnet_scales, self.controlnets) + ): + down_samples, mid_sample = controlnet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_cond, + conditioning_scale=conditioning_scale, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + noise_pred = self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + return noise_pred + + +def onnx_export( + model, + model_args: tuple, + output_path: Path, + ordered_input_names, + output_names, + dynamic_axes, + opset, + use_external_data_format=False, +): + output_path.parent.mkdir(parents=True, exist_ok=True) + # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11, + # so we check the torch version for backwards compatibility + with torch.inference_mode(), torch.autocast("cuda"): + if is_torch_less_than_1_11: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + use_external_data_format=use_external_data_format, + enable_onnx_checker=True, + opset_version=opset, + ) + else: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + opset_version=opset, + ) + + +@torch.no_grad() +def convert_models( + model_path: str, controlnet_path: list, output_path: str, opset: int, fp16: bool = False, sd_xl: bool = False +): + """ + Function to convert models in stable diffusion controlnet pipeline into ONNX format + + Example: + python convert_stable_diffusion_controlnet_to_onnx.py + --model_path danbrown/RevAnimated-v1-2-2 + --controlnet_path lllyasviel/control_v11f1e_sd15_tile ioclab/brightness-controlnet + --output_path path-to-models-stable_diffusion/RevAnimated-v1-2-2 + --fp16 + + Example for SD XL: + python convert_stable_diffusion_controlnet_to_onnx.py + --model_path stabilityai/stable-diffusion-xl-base-1.0 + --controlnet_path SargeZT/sdxl-controlnet-seg + --output_path path-to-models-stable_diffusion/stable-diffusion-xl-base-1.0 + --fp16 + --sd_xl + + Returns: + create 4 onnx models in output path + text_encoder/model.onnx + unet/model.onnx + unet/weights.pb + vae_encoder/model.onnx + vae_decoder/model.onnx + + run test script in diffusers/examples/community + python test_onnx_controlnet.py + --sd_model danbrown/RevAnimated-v1-2-2 + --onnx_model_dir path-to-models-stable_diffusion/RevAnimated-v1-2-2 + --qr_img_path path-to-qr-code-image + """ + dtype = torch.float16 if fp16 else torch.float32 + if fp16 and torch.cuda.is_available(): + device = "cuda" + elif fp16 and not torch.cuda.is_available(): + raise ValueError("`float16` model export is only supported on GPUs with CUDA") + else: + device = "cpu" + + # init controlnet + controlnets = [] + for path in controlnet_path: + controlnet = ControlNetModel.from_pretrained(path, torch_dtype=dtype).to(device) + if is_torch_2_0_1: + controlnet.set_attn_processor(AttnProcessor()) + controlnets.append(controlnet) + + if sd_xl: + if len(controlnets) == 1: + controlnet = controlnets[0] + else: + raise ValueError("MultiControlNet is not yet supported.") + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + model_path, controlnet=controlnet, torch_dtype=dtype, variant="fp16", use_safetensors=True + ).to(device) + else: + pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( + model_path, controlnet=controlnets, torch_dtype=dtype + ).to(device) + + output_path = Path(output_path) + if is_torch_2_0_1: + pipeline.unet.set_attn_processor(AttnProcessor()) + pipeline.vae.set_attn_processor(AttnProcessor()) + + # # TEXT ENCODER + num_tokens = pipeline.text_encoder.config.max_position_embeddings + text_hidden_size = pipeline.text_encoder.config.hidden_size + text_input = pipeline.tokenizer( + "A sample prompt", + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + onnx_export( + pipeline.text_encoder, + # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files + model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)), + output_path=output_path / "text_encoder" / "model.onnx", + ordered_input_names=["input_ids"], + output_names=["last_hidden_state", "pooler_output"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, + opset=opset, + ) + del pipeline.text_encoder + + # # UNET + if sd_xl: + controlnets = torch.nn.ModuleList(controlnets) + unet_controlnet = UNet2DConditionXLControlNetModel(pipeline.unet, controlnets) + unet_in_channels = pipeline.unet.config.in_channels + unet_sample_size = pipeline.unet.config.sample_size + text_hidden_size = 2048 + img_size = 8 * unet_sample_size + unet_path = output_path / "unet" / "model.onnx" + + onnx_export( + unet_controlnet, + model_args=( + torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), + torch.tensor([1.0]).to(device=device, dtype=dtype), + torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype), + torch.randn(len(controlnets), 2, 3, img_size, img_size).to(device=device, dtype=dtype), + torch.randn(len(controlnets), 1).to(device=device, dtype=dtype), + torch.randn(2, 1280).to(device=device, dtype=dtype), + torch.rand(2, 6).to(device=device, dtype=dtype), + ), + output_path=unet_path, + ordered_input_names=[ + "sample", + "timestep", + "encoder_hidden_states", + "controlnet_conds", + "conditioning_scales", + "text_embeds", + "time_ids", + ], + output_names=["noise_pred"], # has to be different from "sample" for correct tracing + dynamic_axes={ + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "controlnet_conds": {1: "2B", 3: "8H", 4: "8W"}, + "text_embeds": {0: "2B"}, + "time_ids": {0: "2B"}, + }, + opset=opset, + use_external_data_format=True, # UNet is > 2GB, so the weights need to be split + ) + unet_model_path = str(unet_path.absolute().as_posix()) + unet_dir = os.path.dirname(unet_model_path) + # optimize onnx + shape_inference.infer_shapes_path(unet_model_path, unet_model_path) + unet_opt_graph = optimize(onnx.load(unet_model_path), name="Unet", verbose=True) + # clean up existing tensor files + shutil.rmtree(unet_dir) + os.mkdir(unet_dir) + # collate external tensor files into one + onnx.save_model( + unet_opt_graph, + unet_model_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.pb", + convert_attribute=False, + ) + del pipeline.unet + else: + controlnets = torch.nn.ModuleList(controlnets) + unet_controlnet = UNet2DConditionControlNetModel(pipeline.unet, controlnets) + unet_in_channels = pipeline.unet.config.in_channels + unet_sample_size = pipeline.unet.config.sample_size + img_size = 8 * unet_sample_size + unet_path = output_path / "unet" / "model.onnx" + + onnx_export( + unet_controlnet, + model_args=( + torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), + torch.tensor([1.0]).to(device=device, dtype=dtype), + torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype), + torch.randn(len(controlnets), 2, 3, img_size, img_size).to(device=device, dtype=dtype), + torch.randn(len(controlnets), 1).to(device=device, dtype=dtype), + ), + output_path=unet_path, + ordered_input_names=[ + "sample", + "timestep", + "encoder_hidden_states", + "controlnet_conds", + "conditioning_scales", + ], + output_names=["noise_pred"], # has to be different from "sample" for correct tracing + dynamic_axes={ + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "controlnet_conds": {1: "2B", 3: "8H", 4: "8W"}, + }, + opset=opset, + use_external_data_format=True, # UNet is > 2GB, so the weights need to be split + ) + unet_model_path = str(unet_path.absolute().as_posix()) + unet_dir = os.path.dirname(unet_model_path) + # optimize onnx + shape_inference.infer_shapes_path(unet_model_path, unet_model_path) + unet_opt_graph = optimize(onnx.load(unet_model_path), name="Unet", verbose=True) + # clean up existing tensor files + shutil.rmtree(unet_dir) + os.mkdir(unet_dir) + # collate external tensor files into one + onnx.save_model( + unet_opt_graph, + unet_model_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.pb", + convert_attribute=False, + ) + del pipeline.unet + + # VAE ENCODER + vae_encoder = pipeline.vae + vae_in_channels = vae_encoder.config.in_channels + vae_sample_size = vae_encoder.config.sample_size + # need to get the raw tensor output (sample) from the encoder + vae_encoder.forward = lambda sample: vae_encoder.encode(sample).latent_dist.sample() + onnx_export( + vae_encoder, + model_args=(torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype),), + output_path=output_path / "vae_encoder" / "model.onnx", + ordered_input_names=["sample"], + output_names=["latent_sample"], + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + opset=opset, + ) + + # VAE DECODER + vae_decoder = pipeline.vae + vae_latent_channels = vae_decoder.config.latent_channels + # forward only through the decoder part + vae_decoder.forward = vae_encoder.decode + onnx_export( + vae_decoder, + model_args=( + torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), + ), + output_path=output_path / "vae_decoder" / "model.onnx", + ordered_input_names=["latent_sample"], + output_names=["sample"], + dynamic_axes={ + "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + opset=opset, + ) + del pipeline.vae + + del pipeline + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--sd_xl", action="store_true", default=False, help="SD XL pipeline") + + parser.add_argument( + "--model_path", + type=str, + required=True, + help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).", + ) + + parser.add_argument( + "--controlnet_path", + nargs="+", + required=True, + help="Path to the `controlnet` checkpoint to convert (either a local directory or on the Hub).", + ) + + parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--opset", + default=14, + type=int, + help="The version of the ONNX operator set to use.", + ) + parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode") + + args = parser.parse_args() + + convert_models(args.model_path, args.controlnet_path, args.output_path, args.opset, args.fp16, args.sd_xl) diff --git a/diffusers/scripts/convert_stable_diffusion_controlnet_to_tensorrt.py b/diffusers/scripts/convert_stable_diffusion_controlnet_to_tensorrt.py new file mode 100755 index 0000000..52ab02c --- /dev/null +++ b/diffusers/scripts/convert_stable_diffusion_controlnet_to_tensorrt.py @@ -0,0 +1,121 @@ +import argparse +import sys + +import tensorrt as trt + + +def convert_models(onnx_path: str, num_controlnet: int, output_path: str, fp16: bool = False, sd_xl: bool = False): + """ + Function to convert models in stable diffusion controlnet pipeline into TensorRT format + + Example: + python convert_stable_diffusion_controlnet_to_tensorrt.py + --onnx_path path-to-models-stable_diffusion/RevAnimated-v1-2-2/unet/model.onnx + --output_path path-to-models-stable_diffusion/RevAnimated-v1-2-2/unet/model.engine + --fp16 + --num_controlnet 2 + + Example for SD XL: + python convert_stable_diffusion_controlnet_to_tensorrt.py + --onnx_path path-to-models-stable_diffusion/stable-diffusion-xl-base-1.0/unet/model.onnx + --output_path path-to-models-stable_diffusion/stable-diffusion-xl-base-1.0/unet/model.engine + --fp16 + --num_controlnet 1 + --sd_xl + + Returns: + unet/model.engine + + run test script in diffusers/examples/community + python test_onnx_controlnet.py + --sd_model danbrown/RevAnimated-v1-2-2 + --onnx_model_dir path-to-models-stable_diffusion/RevAnimated-v1-2-2 + --unet_engine_path path-to-models-stable_diffusion/stable-diffusion-xl-base-1.0/unet/model.engine + --qr_img_path path-to-qr-code-image + """ + # UNET + if sd_xl: + batch_size = 1 + unet_in_channels = 4 + unet_sample_size = 64 + num_tokens = 77 + text_hidden_size = 2048 + img_size = 512 + + text_embeds_shape = (2 * batch_size, 1280) + time_ids_shape = (2 * batch_size, 6) + else: + batch_size = 1 + unet_in_channels = 4 + unet_sample_size = 64 + num_tokens = 77 + text_hidden_size = 768 + img_size = 512 + batch_size = 1 + + latents_shape = (2 * batch_size, unet_in_channels, unet_sample_size, unet_sample_size) + embed_shape = (2 * batch_size, num_tokens, text_hidden_size) + controlnet_conds_shape = (num_controlnet, 2 * batch_size, 3, img_size, img_size) + + TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) + TRT_BUILDER = trt.Builder(TRT_LOGGER) + TRT_RUNTIME = trt.Runtime(TRT_LOGGER) + + network = TRT_BUILDER.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + onnx_parser = trt.OnnxParser(network, TRT_LOGGER) + + parse_success = onnx_parser.parse_from_file(onnx_path) + for idx in range(onnx_parser.num_errors): + print(onnx_parser.get_error(idx)) + if not parse_success: + sys.exit("ONNX model parsing failed") + print("Load Onnx model done") + + profile = TRT_BUILDER.create_optimization_profile() + + profile.set_shape("sample", latents_shape, latents_shape, latents_shape) + profile.set_shape("encoder_hidden_states", embed_shape, embed_shape, embed_shape) + profile.set_shape("controlnet_conds", controlnet_conds_shape, controlnet_conds_shape, controlnet_conds_shape) + if sd_xl: + profile.set_shape("text_embeds", text_embeds_shape, text_embeds_shape, text_embeds_shape) + profile.set_shape("time_ids", time_ids_shape, time_ids_shape, time_ids_shape) + + config = TRT_BUILDER.create_builder_config() + config.add_optimization_profile(profile) + config.set_preview_feature(trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805, True) + if fp16: + config.set_flag(trt.BuilderFlag.FP16) + + plan = TRT_BUILDER.build_serialized_network(network, config) + if plan is None: + sys.exit("Failed building engine") + print("Succeeded building engine") + + engine = TRT_RUNTIME.deserialize_cuda_engine(plan) + + ## save TRT engine + with open(output_path, "wb") as f: + f.write(engine.serialize()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--sd_xl", action="store_true", default=False, help="SD XL pipeline") + + parser.add_argument( + "--onnx_path", + type=str, + required=True, + help="Path to the onnx checkpoint to convert", + ) + + parser.add_argument("--num_controlnet", type=int) + + parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.") + + parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode") + + args = parser.parse_args() + + convert_models(args.onnx_path, args.num_controlnet, args.output_path, args.fp16, args.sd_xl) diff --git a/diffusers/scripts/convert_svd_to_diffusers.py b/diffusers/scripts/convert_svd_to_diffusers.py new file mode 100755 index 0000000..3243ce2 --- /dev/null +++ b/diffusers/scripts/convert_svd_to_diffusers.py @@ -0,0 +1,730 @@ +from diffusers.utils import is_accelerate_available, logging + + +if is_accelerate_available(): + pass + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: + unet_params = original_config.model.params.unet_config.params + else: + unet_params = original_config.model.params.network_config.params + + vae_params = original_config.model.params.first_stage_config.params.encoder_config.params + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnDownBlockSpatioTemporal" + if resolution in unet_params.attention_resolutions + else "DownBlockSpatioTemporal" + ) + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnUpBlockSpatioTemporal" + if resolution in unet_params.attention_resolutions + else "UpBlockSpatioTemporal" + ) + up_block_types.append(block_type) + resolution //= 2 + + if unet_params.transformer_depth is not None: + transformer_layers_per_block = ( + unet_params.transformer_depth + if isinstance(unet_params.transformer_depth, int) + else list(unet_params.transformer_depth) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params.model_channels // unet_params.num_head_channels + head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params.context_dim is not None: + context_dim = ( + unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] + ) + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + addition_time_embed_dim = 256 + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params.disable_self_attentions + + if "num_classes" in unet_params and isinstance(unet_params.num_classes, int): + config["num_class_embeds"] = unet_params.num_classes + + if controlnet: + config["conditioning_channels"] = unet_params.hint_channels + else: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def assign_to_checkpoint( + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + config=None, + mid_block_suffix="", +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + if mid_block_suffix is not None: + mid_block_suffix = f".{mid_block_suffix}" + else: + mid_block_suffix = "" + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + if new_path == "mid_block.resnets.0.spatial_res_block.norm1.weight": + print("yeyy") + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = new_item.replace("time_stack", "temporal_transformer_blocks") + + new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias") + new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight") + new_item = new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias") + new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight") + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = new_item.replace("time_stack.", "") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + # if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + spatial_resnets = [ + key + for key in input_blocks[i] + if f"input_blocks.{i}.0" in key + and ( + f"input_blocks.{i}.0.op" not in key + and f"input_blocks.{i}.0.time_stack" not in key + and f"input_blocks.{i}.0.time_mixer" not in key + ) + ] + temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key] + # import ipdb; ipdb.set_trace() + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(spatial_resnets) + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + paths = renew_resnet_paths(temporal_resnets) + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + # TODO resnet time_mixer.mix_factor + if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: + new_checkpoint[ + f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" + ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + # import ipdb; ipdb.set_trace() + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key] + resnet_0_paths = renew_resnet_paths(resnet_0_spatial) + # import ipdb; ipdb.set_trace() + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block" + ) + + resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key] + resnet_0_paths = renew_resnet_paths(resnet_0_temporal) + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block" + ) + + resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key] + resnet_1_paths = renew_resnet_paths(resnet_1_spatial) + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block" + ) + + resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key] + resnet_1_paths = renew_resnet_paths(resnet_1_temporal) + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block" + ) + + new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[ + "middle_block.0.time_mixer.mix_factor" + ] + new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[ + "middle_block.2.time_mixer.mix_factor" + ] + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + spatial_resnets = [ + key + for key in output_blocks[i] + if f"output_blocks.{i}.0" in key + and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key) + ] + # import ipdb; ipdb.set_trace() + + temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key] + + paths = renew_resnet_paths(spatial_resnets) + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + paths = renew_resnet_paths(temporal_resnets) + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: + new_checkpoint[ + f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" + ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"] + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key] + if len(attentions): + paths = renew_attention_paths(attentions) + # import ipdb; ipdb.set_trace() + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + spatial_layers = [ + layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer + ] + resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1) + # import ipdb; ipdb.set_trace() + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join( + ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]] + ) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + temporal_layers = [ + layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in key + ] + resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1) + # import ipdb; ipdb.set_trace() + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join( + ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]] + ) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[ + f"output_blocks.{str(i)}.0.time_mixer.mix_factor" + ] + + return new_checkpoint + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # Temporal resnet + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = new_item.replace("time_stack.", "temporal_res_block.") + + # Spatial resnet + new_item = new_item.replace("conv1", "spatial_res_block.conv1") + new_item = new_item.replace("norm1", "spatial_res_block.norm1") + + new_item = new_item.replace("conv2", "spatial_res_block.conv2") + new_item = new_item.replace("norm2", "spatial_res_block.norm2") + + new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut") + + new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"] + new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"] + + # new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + # new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + # new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + # new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint diff --git a/diffusers/scripts/convert_tiny_autoencoder_to_diffusers.py b/diffusers/scripts/convert_tiny_autoencoder_to_diffusers.py new file mode 100755 index 0000000..9bb2df9 --- /dev/null +++ b/diffusers/scripts/convert_tiny_autoencoder_to_diffusers.py @@ -0,0 +1,71 @@ +import argparse + +import safetensors.torch + +from diffusers import AutoencoderTiny + + +""" +Example - From the diffusers root directory: + +Download the weights: +```sh +$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors +$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors +``` + +Convert the model: +```sh +$ python scripts/convert_tiny_autoencoder_to_diffusers.py \ + --encoder_ckpt_path taesd_encoder.safetensors \ + --decoder_ckpt_path taesd_decoder.safetensors \ + --dump_path taesd-diffusers +``` +""" + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument( + "--encoder_ckpt_path", + default=None, + type=str, + required=True, + help="Path to the encoder ckpt.", + ) + parser.add_argument( + "--decoder_ckpt_path", + default=None, + type=str, + required=True, + help="Path to the decoder ckpt.", + ) + parser.add_argument( + "--use_safetensors", action="store_true", help="Whether to serialize in the safetensors format." + ) + args = parser.parse_args() + + print("Loading the original state_dicts of the encoder and the decoder...") + encoder_state_dict = safetensors.torch.load_file(args.encoder_ckpt_path) + decoder_state_dict = safetensors.torch.load_file(args.decoder_ckpt_path) + + print("Populating the state_dicts in the diffusers format...") + tiny_autoencoder = AutoencoderTiny() + new_state_dict = {} + + # Modify the encoder state dict. + for k in encoder_state_dict: + new_state_dict.update({f"encoder.layers.{k}": encoder_state_dict[k]}) + + # Modify the decoder state dict. + for k in decoder_state_dict: + layer_id = int(k.split(".")[0]) - 1 + new_k = str(layer_id) + "." + ".".join(k.split(".")[1:]) + new_state_dict.update({f"decoder.layers.{new_k}": decoder_state_dict[k]}) + + # Assertion tests with the original implementation can be found here: + # https://gist.github.com/sayakpaul/337b0988f08bd2cf2b248206f760e28f + tiny_autoencoder.load_state_dict(new_state_dict) + print("Population successful, serializing...") + tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors) diff --git a/diffusers/scripts/convert_unclip_txt2img_to_image_variation.py b/diffusers/scripts/convert_unclip_txt2img_to_image_variation.py new file mode 100755 index 0000000..07f8ebf --- /dev/null +++ b/diffusers/scripts/convert_unclip_txt2img_to_image_variation.py @@ -0,0 +1,41 @@ +import argparse + +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from diffusers import UnCLIPImageVariationPipeline, UnCLIPPipeline + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--txt2img_unclip", + default="kakaobrain/karlo-v1-alpha", + type=str, + required=False, + help="The pretrained txt2img unclip.", + ) + + args = parser.parse_args() + + txt2img = UnCLIPPipeline.from_pretrained(args.txt2img_unclip) + + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + + img2img = UnCLIPImageVariationPipeline( + decoder=txt2img.decoder, + text_encoder=txt2img.text_encoder, + tokenizer=txt2img.tokenizer, + text_proj=txt2img.text_proj, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + super_res_first=txt2img.super_res_first, + super_res_last=txt2img.super_res_last, + decoder_scheduler=txt2img.decoder_scheduler, + super_res_scheduler=txt2img.super_res_scheduler, + ) + + img2img.save_pretrained(args.dump_path) diff --git a/diffusers/scripts/convert_unidiffuser_to_diffusers.py b/diffusers/scripts/convert_unidiffuser_to_diffusers.py new file mode 100755 index 0000000..4c38172 --- /dev/null +++ b/diffusers/scripts/convert_unidiffuser_to_diffusers.py @@ -0,0 +1,786 @@ +# Convert the original UniDiffuser checkpoints into diffusers equivalents. + +import argparse +from argparse import Namespace + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, + GPT2Tokenizer, +) + +from diffusers import ( + AutoencoderKL, + DPMSolverMultistepScheduler, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, +) + + +SCHEDULER_CONFIG = Namespace( + **{ + "beta_start": 0.00085, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "solver_order": 3, + } +) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint +# config.num_head_channels => num_head_channels +def assign_to_checkpoint( + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + num_head_channels=1, +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // num_head_channels // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def create_vae_diffusers_config(config_type): + # Hardcoded for now + if args.config_type == "test": + vae_config = create_vae_diffusers_config_test() + elif args.config_type == "big": + vae_config = create_vae_diffusers_config_big() + else: + raise NotImplementedError( + f"Config type {config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + return vae_config + + +def create_unidiffuser_unet_config(config_type, version): + # Hardcoded for now + if args.config_type == "test": + unet_config = create_unidiffuser_unet_config_test() + elif args.config_type == "big": + unet_config = create_unidiffuser_unet_config_big() + else: + raise NotImplementedError( + f"Config type {config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + # Unidiffuser-v1 uses data type embeddings + if version == 1: + unet_config["use_data_type_embedding"] = True + return unet_config + + +def create_text_decoder_config(config_type): + # Hardcoded for now + if args.config_type == "test": + text_decoder_config = create_text_decoder_config_test() + elif args.config_type == "big": + text_decoder_config = create_text_decoder_config_big() + else: + raise NotImplementedError( + f"Config type {config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + return text_decoder_config + + +# Hardcoded configs for test versions of the UniDiffuser models, corresponding to those in the fast default tests. +def create_vae_diffusers_config_test(): + vae_config = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "block_out_channels": [32, 64], + "latent_channels": 4, + "layers_per_block": 1, + } + return vae_config + + +def create_unidiffuser_unet_config_test(): + unet_config = { + "text_dim": 32, + "clip_img_dim": 32, + "num_text_tokens": 77, + "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 4, + "out_channels": 4, + "num_layers": 2, + "dropout": 0.0, + "norm_num_groups": 32, + "attention_bias": False, + "sample_size": 16, + "patch_size": 2, + "activation_fn": "gelu", + "num_embeds_ada_norm": 1000, + "norm_type": "layer_norm", + "block_type": "unidiffuser", + "pre_layer_norm": False, + "use_timestep_embedding": False, + "norm_elementwise_affine": True, + "use_patch_pos_embed": False, + "ff_final_dropout": True, + "use_data_type_embedding": False, + } + return unet_config + + +def create_text_decoder_config_test(): + text_decoder_config = { + "prefix_length": 77, + "prefix_inner_dim": 32, + "prefix_hidden_dim": 32, + "vocab_size": 1025, # 1024 + 1 for new EOS token + "n_positions": 1024, + "n_embd": 32, + "n_layer": 5, + "n_head": 4, + "n_inner": 37, + "activation_function": "gelu", + "resid_pdrop": 0.1, + "embd_pdrop": 0.1, + "attn_pdrop": 0.1, + "layer_norm_epsilon": 1e-5, + "initializer_range": 0.02, + } + return text_decoder_config + + +# Hardcoded configs for the UniDiffuser V1 model at https://huggingface.co/thu-ml/unidiffuser-v1 +# See also https://github.com/thu-ml/unidiffuser/blob/main/configs/sample_unidiffuser_v1.py +def create_vae_diffusers_config_big(): + vae_config = { + "sample_size": 256, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + "block_out_channels": [128, 256, 512, 512], + "latent_channels": 4, + "layers_per_block": 2, + } + return vae_config + + +def create_unidiffuser_unet_config_big(): + unet_config = { + "text_dim": 64, + "clip_img_dim": 512, + "num_text_tokens": 77, + "num_attention_heads": 24, + "attention_head_dim": 64, + "in_channels": 4, + "out_channels": 4, + "num_layers": 30, + "dropout": 0.0, + "norm_num_groups": 32, + "attention_bias": False, + "sample_size": 64, + "patch_size": 2, + "activation_fn": "gelu", + "num_embeds_ada_norm": 1000, + "norm_type": "layer_norm", + "block_type": "unidiffuser", + "pre_layer_norm": False, + "use_timestep_embedding": False, + "norm_elementwise_affine": True, + "use_patch_pos_embed": False, + "ff_final_dropout": True, + "use_data_type_embedding": False, + } + return unet_config + + +# From https://huggingface.co/gpt2/blob/main/config.json, the GPT2 checkpoint used by UniDiffuser +def create_text_decoder_config_big(): + text_decoder_config = { + "prefix_length": 77, + "prefix_inner_dim": 768, + "prefix_hidden_dim": 64, + "vocab_size": 50258, # 50257 + 1 for new EOS token + "n_positions": 1024, + "n_embd": 768, + "n_layer": 12, + "n_head": 12, + "n_inner": 3072, + "activation_function": "gelu", + "resid_pdrop": 0.1, + "embd_pdrop": 0.1, + "attn_pdrop": 0.1, + "layer_norm_epsilon": 1e-5, + "initializer_range": 0.02, + } + return text_decoder_config + + +# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint +def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1): + """ + Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL. + """ + # autoencoder_kl.pth ckpt is a torch state dict + vae_state_dict = torch.load(ckpt, map_location="cpu") + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + conv_attn_to_linear(new_checkpoint) + + missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_checkpoint) + for missing_key in missing_keys: + print(f"Missing key: {missing_key}") + for unexpected_key in unexpected_keys: + print(f"Unexpected key: {unexpected_key}") + + return diffusers_model + + +def convert_uvit_block_to_diffusers_block( + uvit_state_dict, + new_state_dict, + block_prefix, + new_prefix="transformer.transformer_", + skip_connection=False, +): + """ + Maps the keys in a UniDiffuser transformer block (`Block`) to the keys in a diffusers transformer block + (`UTransformerBlock`/`UniDiffuserBlock`). + """ + prefix = new_prefix + block_prefix + if skip_connection: + new_state_dict[prefix + ".skip.skip_linear.weight"] = uvit_state_dict[block_prefix + ".skip_linear.weight"] + new_state_dict[prefix + ".skip.skip_linear.bias"] = uvit_state_dict[block_prefix + ".skip_linear.bias"] + new_state_dict[prefix + ".skip.norm.weight"] = uvit_state_dict[block_prefix + ".norm1.weight"] + new_state_dict[prefix + ".skip.norm.bias"] = uvit_state_dict[block_prefix + ".norm1.bias"] + + # Create the prefix string for out_blocks. + prefix += ".block" + + # Split up attention qkv.weight into to_q.weight, to_k.weight, to_v.weight + qkv = uvit_state_dict[block_prefix + ".attn.qkv.weight"] + new_attn_keys = [".attn1.to_q.weight", ".attn1.to_k.weight", ".attn1.to_v.weight"] + new_attn_keys = [prefix + key for key in new_attn_keys] + shape = qkv.shape[0] // len(new_attn_keys) + for i, attn_key in enumerate(new_attn_keys): + new_state_dict[attn_key] = qkv[i * shape : (i + 1) * shape] + + new_state_dict[prefix + ".attn1.to_out.0.weight"] = uvit_state_dict[block_prefix + ".attn.proj.weight"] + new_state_dict[prefix + ".attn1.to_out.0.bias"] = uvit_state_dict[block_prefix + ".attn.proj.bias"] + new_state_dict[prefix + ".norm1.weight"] = uvit_state_dict[block_prefix + ".norm2.weight"] + new_state_dict[prefix + ".norm1.bias"] = uvit_state_dict[block_prefix + ".norm2.bias"] + new_state_dict[prefix + ".ff.net.0.proj.weight"] = uvit_state_dict[block_prefix + ".mlp.fc1.weight"] + new_state_dict[prefix + ".ff.net.0.proj.bias"] = uvit_state_dict[block_prefix + ".mlp.fc1.bias"] + new_state_dict[prefix + ".ff.net.2.weight"] = uvit_state_dict[block_prefix + ".mlp.fc2.weight"] + new_state_dict[prefix + ".ff.net.2.bias"] = uvit_state_dict[block_prefix + ".mlp.fc2.bias"] + new_state_dict[prefix + ".norm3.weight"] = uvit_state_dict[block_prefix + ".norm3.weight"] + new_state_dict[prefix + ".norm3.bias"] = uvit_state_dict[block_prefix + ".norm3.bias"] + + return uvit_state_dict, new_state_dict + + +def convert_uvit_to_diffusers(ckpt, diffusers_model): + """ + Converts a UniDiffuser uvit_v*.pth checkpoint to a diffusers UniDiffusersModel. + """ + # uvit_v*.pth ckpt is a torch state dict + uvit_state_dict = torch.load(ckpt, map_location="cpu") + + new_state_dict = {} + + # Input layers + new_state_dict["vae_img_in.proj.weight"] = uvit_state_dict["patch_embed.proj.weight"] + new_state_dict["vae_img_in.proj.bias"] = uvit_state_dict["patch_embed.proj.bias"] + new_state_dict["clip_img_in.weight"] = uvit_state_dict["clip_img_embed.weight"] + new_state_dict["clip_img_in.bias"] = uvit_state_dict["clip_img_embed.bias"] + new_state_dict["text_in.weight"] = uvit_state_dict["text_embed.weight"] + new_state_dict["text_in.bias"] = uvit_state_dict["text_embed.bias"] + + new_state_dict["pos_embed"] = uvit_state_dict["pos_embed"] + + # Handle data type token embeddings for UniDiffuser-v1 + if "token_embedding.weight" in uvit_state_dict and diffusers_model.use_data_type_embedding: + new_state_dict["data_type_pos_embed_token"] = uvit_state_dict["pos_embed_token"] + new_state_dict["data_type_token_embedding.weight"] = uvit_state_dict["token_embedding.weight"] + + # Also initialize the PatchEmbedding in UTransformer2DModel with the PatchEmbedding from the checkpoint. + # This isn't used in the current implementation, so might want to remove. + new_state_dict["transformer.pos_embed.proj.weight"] = uvit_state_dict["patch_embed.proj.weight"] + new_state_dict["transformer.pos_embed.proj.bias"] = uvit_state_dict["patch_embed.proj.bias"] + + # Output layers + new_state_dict["transformer.norm_out.weight"] = uvit_state_dict["norm.weight"] + new_state_dict["transformer.norm_out.bias"] = uvit_state_dict["norm.bias"] + + new_state_dict["vae_img_out.weight"] = uvit_state_dict["decoder_pred.weight"] + new_state_dict["vae_img_out.bias"] = uvit_state_dict["decoder_pred.bias"] + new_state_dict["clip_img_out.weight"] = uvit_state_dict["clip_img_out.weight"] + new_state_dict["clip_img_out.bias"] = uvit_state_dict["clip_img_out.bias"] + new_state_dict["text_out.weight"] = uvit_state_dict["text_out.weight"] + new_state_dict["text_out.bias"] = uvit_state_dict["text_out.bias"] + + # in_blocks + in_blocks_prefixes = {".".join(layer.split(".")[:2]) for layer in uvit_state_dict if "in_blocks" in layer} + for in_block_prefix in list(in_blocks_prefixes): + convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, in_block_prefix) + + # mid_block + # Assume there's only one mid block + convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, "mid_block") + + # out_blocks + out_blocks_prefixes = {".".join(layer.split(".")[:2]) for layer in uvit_state_dict if "out_blocks" in layer} + for out_block_prefix in list(out_blocks_prefixes): + convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, out_block_prefix, skip_connection=True) + + missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_state_dict) + for missing_key in missing_keys: + print(f"Missing key: {missing_key}") + for unexpected_key in unexpected_keys: + print(f"Unexpected key: {unexpected_key}") + + return diffusers_model + + +def convert_caption_decoder_to_diffusers(ckpt, diffusers_model): + """ + Converts a UniDiffuser caption_decoder.pth checkpoint to a diffusers UniDiffuserTextDecoder. + """ + # caption_decoder.pth ckpt is a torch state dict + checkpoint_state_dict = torch.load(ckpt, map_location="cpu") + decoder_state_dict = {} + # Remove the "module." prefix, if necessary + caption_decoder_key = "module." + for key in checkpoint_state_dict: + if key.startswith(caption_decoder_key): + decoder_state_dict[key.replace(caption_decoder_key, "")] = checkpoint_state_dict.get(key) + else: + decoder_state_dict[key] = checkpoint_state_dict.get(key) + + new_state_dict = {} + + # Encoder and Decoder + new_state_dict["encode_prefix.weight"] = decoder_state_dict["encode_prefix.weight"] + new_state_dict["encode_prefix.bias"] = decoder_state_dict["encode_prefix.bias"] + new_state_dict["decode_prefix.weight"] = decoder_state_dict["decode_prefix.weight"] + new_state_dict["decode_prefix.bias"] = decoder_state_dict["decode_prefix.bias"] + + # Internal GPT2LMHeadModel transformer model + for key, val in decoder_state_dict.items(): + if key.startswith("gpt"): + suffix = key[len("gpt") :] + new_state_dict["transformer" + suffix] = val + + missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_state_dict) + for missing_key in missing_keys: + print(f"Missing key: {missing_key}") + for unexpected_key in unexpected_keys: + print(f"Unexpected key: {unexpected_key}") + + return diffusers_model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--caption_decoder_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to caption decoder checkpoint to convert.", + ) + parser.add_argument( + "--uvit_checkpoint_path", default=None, type=str, required=False, help="Path to U-ViT checkpoint to convert." + ) + parser.add_argument( + "--vae_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to VAE checkpoint to convert.", + ) + parser.add_argument( + "--pipeline_output_path", + default=None, + type=str, + required=True, + help="Path to save the output pipeline to.", + ) + parser.add_argument( + "--config_type", + default="test", + type=str, + help=( + "Config type to use. Should be 'test' to create small models for testing or 'big' to convert a full" + " checkpoint." + ), + ) + parser.add_argument( + "--version", + default=0, + type=int, + help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.", + ) + parser.add_argument( + "--safe_serialization", + action="store_true", + help="Whether to use safetensors/safe seialization when saving the pipeline.", + ) + + args = parser.parse_args() + + # Convert the VAE model. + if args.vae_checkpoint_path is not None: + vae_config = create_vae_diffusers_config(args.config_type) + vae = AutoencoderKL(**vae_config) + vae = convert_vae_to_diffusers(args.vae_checkpoint_path, vae) + + # Convert the U-ViT ("unet") model. + if args.uvit_checkpoint_path is not None: + unet_config = create_unidiffuser_unet_config(args.config_type, args.version) + unet = UniDiffuserModel(**unet_config) + unet = convert_uvit_to_diffusers(args.uvit_checkpoint_path, unet) + + # Convert the caption decoder ("text_decoder") model. + if args.caption_decoder_checkpoint_path is not None: + text_decoder_config = create_text_decoder_config(args.config_type) + text_decoder = UniDiffuserTextDecoder(**text_decoder_config) + text_decoder = convert_caption_decoder_to_diffusers(args.caption_decoder_checkpoint_path, text_decoder) + + # Scheduler is the same for both the test and big models. + scheduler_config = SCHEDULER_CONFIG + scheduler = DPMSolverMultistepScheduler( + beta_start=scheduler_config.beta_start, + beta_end=scheduler_config.beta_end, + beta_schedule=scheduler_config.beta_schedule, + solver_order=scheduler_config.solver_order, + ) + + if args.config_type == "test": + # Make a small random CLIPTextModel + torch.manual_seed(0) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(clip_text_encoder_config) + clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # Make a small random CLIPVisionModel and accompanying CLIPImageProcessor + torch.manual_seed(0) + clip_image_encoder_config = CLIPVisionConfig( + image_size=32, + patch_size=2, + num_channels=3, + hidden_size=32, + projection_dim=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + dropout=0.1, + attention_dropout=0.1, + initializer_range=0.02, + ) + image_encoder = CLIPVisionModelWithProjection(clip_image_encoder_config) + image_processor = CLIPImageProcessor(crop_size=32, size=32) + + # Note that the text_decoder should already have its token embeddings resized. + text_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-GPT2Model") + eos = "<|EOS|>" + special_tokens_dict = {"eos_token": eos} + text_tokenizer.add_special_tokens(special_tokens_dict) + elif args.config_type == "big": + text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32") + image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32") + + # Note that the text_decoder should already have its token embeddings resized. + text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + eos = "<|EOS|>" + special_tokens_dict = {"eos_token": eos} + text_tokenizer.add_special_tokens(special_tokens_dict) + else: + raise NotImplementedError( + f"Config type {args.config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + + pipeline = UniDiffuserPipeline( + vae=vae, + text_encoder=text_encoder, + image_encoder=image_encoder, + clip_image_processor=image_processor, + clip_tokenizer=clip_tokenizer, + text_decoder=text_decoder, + text_tokenizer=text_tokenizer, + unet=unet, + scheduler=scheduler, + ) + pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization) diff --git a/diffusers/scripts/convert_vae_diff_to_onnx.py b/diffusers/scripts/convert_vae_diff_to_onnx.py new file mode 100755 index 0000000..e023e04 --- /dev/null +++ b/diffusers/scripts/convert_vae_diff_to_onnx.py @@ -0,0 +1,122 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path + +import torch +from packaging import version +from torch.onnx import export + +from diffusers import AutoencoderKL + + +is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11") + + +def onnx_export( + model, + model_args: tuple, + output_path: Path, + ordered_input_names, + output_names, + dynamic_axes, + opset, + use_external_data_format=False, +): + output_path.parent.mkdir(parents=True, exist_ok=True) + # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11, + # so we check the torch version for backwards compatibility + if is_torch_less_than_1_11: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + use_external_data_format=use_external_data_format, + enable_onnx_checker=True, + opset_version=opset, + ) + else: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + opset_version=opset, + ) + + +@torch.no_grad() +def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False): + dtype = torch.float16 if fp16 else torch.float32 + if fp16 and torch.cuda.is_available(): + device = "cuda" + elif fp16 and not torch.cuda.is_available(): + raise ValueError("`float16` model export is only supported on GPUs with CUDA") + else: + device = "cpu" + output_path = Path(output_path) + + # VAE DECODER + vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae") + vae_latent_channels = vae_decoder.config.latent_channels + # forward only through the decoder part + vae_decoder.forward = vae_decoder.decode + onnx_export( + vae_decoder, + model_args=( + torch.randn(1, vae_latent_channels, 25, 25).to(device=device, dtype=dtype), + False, + ), + output_path=output_path / "vae_decoder" / "model.onnx", + ordered_input_names=["latent_sample", "return_dict"], + output_names=["sample"], + dynamic_axes={ + "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + opset=opset, + ) + del vae_decoder + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_path", + type=str, + required=True, + help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).", + ) + + parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.") + parser.add_argument( + "--opset", + default=14, + type=int, + help="The version of the ONNX operator set to use.", + ) + parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode") + + args = parser.parse_args() + print(args.output_path) + convert_models(args.model_path, args.output_path, args.opset, args.fp16) + print("SD: Done: ONNX") diff --git a/diffusers/scripts/convert_vae_pt_to_diffusers.py b/diffusers/scripts/convert_vae_pt_to_diffusers.py new file mode 100755 index 0000000..a4f967c --- /dev/null +++ b/diffusers/scripts/convert_vae_pt_to_diffusers.py @@ -0,0 +1,159 @@ +import argparse +import io + +import requests +import torch +import yaml + +from diffusers import AutoencoderKL +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + assign_to_checkpoint, + conv_attn_to_linear, + create_vae_diffusers_config, + renew_vae_attention_paths, + renew_vae_resnet_paths, +) + + +def custom_convert_ldm_vae_checkpoint(checkpoint, config): + vae_state_dict = checkpoint + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def vae_pt_to_vae_diffuser( + checkpoint_path: str, + output_path: str, +): + # Only support V1 + r = requests.get( + " https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + ) + io_obj = io.BytesIO(r.content) + + original_config = yaml.safe_load(io_obj) + image_size = 512 + device = "cuda" if torch.cuda.is_available() else "cpu" + if checkpoint_path.endswith("safetensors"): + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"] + + # Convert the VAE model. + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + vae.save_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.") + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.") + + args = parser.parse_args() + + vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path) diff --git a/diffusers/scripts/convert_versatile_diffusion_to_diffusers.py b/diffusers/scripts/convert_versatile_diffusion_to_diffusers.py new file mode 100755 index 0000000..41e2e01 --- /dev/null +++ b/diffusers/scripts/convert_versatile_diffusion_to_diffusers.py @@ -0,0 +1,791 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the Versatile Stable Diffusion checkpoints.""" + +import argparse +from argparse import Namespace + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, + VersatileDiffusionPipeline, +) +from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel + + +SCHEDULER_CONFIG = Namespace( + **{ + "beta_linear_start": 0.00085, + "beta_linear_end": 0.012, + "timesteps": 1000, + "scale_factor": 0.18215, + } +) + +IMAGE_UNET_CONFIG = Namespace( + **{ + "input_channels": 4, + "model_channels": 320, + "output_channels": 4, + "num_noattn_blocks": [2, 2, 2, 2], + "channel_mult": [1, 2, 4, 4], + "with_attn": [True, True, True, False], + "num_heads": 8, + "context_dim": 768, + "use_checkpoint": True, + } +) + +TEXT_UNET_CONFIG = Namespace( + **{ + "input_channels": 768, + "model_channels": 320, + "output_channels": 768, + "num_noattn_blocks": [2, 2, 2, 2], + "channel_mult": [1, 2, 4, 4], + "second_dim": [4, 4, 4, 4], + "with_attn": [True, True, True, False], + "num_heads": 8, + "context_dim": 768, + "use_checkpoint": True, + } +) + +AUTOENCODER_CONFIG = Namespace( + **{ + "double_z": True, + "z_channels": 4, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [1, 2, 4, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + } +) + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif path["old"] in old_checkpoint: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_image_unet_diffusers_config(unet_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if unet_params.with_attn[i] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if unet_params.with_attn[-i - 1] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks): + raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.") + + config = { + "sample_size": None, + "in_channels": unet_params.input_channels, + "out_channels": unet_params.output_channels, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_noattn_blocks[0], + "cross_attention_dim": unet_params.context_dim, + "attention_head_dim": unet_params.num_heads, + } + + return config + + +def create_text_unet_diffusers_config(unet_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlockFlat" if unet_params.with_attn[i] else "DownBlockFlat" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlockFlat" if unet_params.with_attn[-i - 1] else "UpBlockFlat" + up_block_types.append(block_type) + resolution //= 2 + + if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks): + raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.") + + config = { + "sample_size": None, + "in_channels": (unet_params.input_channels, 1, 1), + "out_channels": (unet_params.output_channels, 1, 1), + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_noattn_blocks[0], + "cross_attention_dim": unet_params.context_dim, + "attention_head_dim": unet_params.num_heads, + } + + return config + + +def create_vae_diffusers_config(vae_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": vae_params.resolution, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } + return config + + +def create_diffusers_scheduler(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100: + print("Checkpoint has both EMA and non-EMA weights.") + if extract_ema: + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["model.diffusion_model.time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["model.diffusion_model.time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["model.diffusion_model.time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["model.diffusion_model.time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + elif f"input_blocks.{i}.0.weight" in unet_state_dict: + # text_unet uses linear layers in place of downsamplers + shape = unet_state_dict[f"input_blocks.{i}.0.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if ["conv.weight", "conv.bias"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + elif f"output_blocks.{i}.1.weight" in unet_state_dict: + # text_unet uses linear layers in place of upsamplers + shape = unet_state_dict[f"output_blocks.{i}.1.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop( + f"output_blocks.{i}.1.weight" + ) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop( + f"output_blocks.{i}.1.bias" + ) + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + elif f"output_blocks.{i}.2.weight" in unet_state_dict: + # text_unet uses linear layers in place of upsamplers + shape = unet_state_dict[f"output_blocks.{i}.2.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop( + f"output_blocks.{i}.2.weight" + ) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop( + f"output_blocks.{i}.2.bias" + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +def convert_vd_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + for key in keys: + vae_state_dict[key] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--unet_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--vae_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--optimus_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--scheduler_type", + default="pndm", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + args = parser.parse_args() + + scheduler_config = SCHEDULER_CONFIG + + num_train_timesteps = scheduler_config.timesteps + beta_start = scheduler_config.beta_linear_start + beta_end = scheduler_config.beta_linear_end + if args.scheduler_type == "pndm": + scheduler = PNDMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + skip_prk_steps=True, + steps_offset=1, + ) + elif args.scheduler_type == "lms": + scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") + elif args.scheduler_type == "euler": + scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") + elif args.scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear" + ) + elif args.scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear" + ) + elif args.scheduler_type == "ddim": + scheduler = DDIMScheduler( + beta_start=beta_start, + beta_end=beta_end, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + else: + raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel models. + if args.unet_checkpoint_path is not None: + # image UNet + image_unet_config = create_image_unet_diffusers_config(IMAGE_UNET_CONFIG) + checkpoint = torch.load(args.unet_checkpoint_path) + converted_image_unet_checkpoint = convert_vd_unet_checkpoint( + checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema + ) + image_unet = UNet2DConditionModel(**image_unet_config) + image_unet.load_state_dict(converted_image_unet_checkpoint) + + # text UNet + text_unet_config = create_text_unet_diffusers_config(TEXT_UNET_CONFIG) + converted_text_unet_checkpoint = convert_vd_unet_checkpoint( + checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema + ) + text_unet = UNetFlatConditionModel(**text_unet_config) + text_unet.load_state_dict(converted_text_unet_checkpoint) + + # Convert the VAE model. + if args.vae_checkpoint_path is not None: + vae_config = create_vae_diffusers_config(AUTOENCODER_CONFIG) + checkpoint = torch.load(args.vae_checkpoint_path) + converted_vae_checkpoint = convert_vd_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + image_feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") + text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + + pipe = VersatileDiffusionPipeline( + scheduler=scheduler, + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + ) + pipe.save_pretrained(args.dump_path) diff --git a/diffusers/scripts/convert_vq_diffusion_to_diffusers.py b/diffusers/scripts/convert_vq_diffusion_to_diffusers.py new file mode 100755 index 0000000..7da6b40 --- /dev/null +++ b/diffusers/scripts/convert_vq_diffusion_to_diffusers.py @@ -0,0 +1,916 @@ +""" +This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers. + +It currently only supports porting the ITHQ dataset. + +ITHQ dataset: +```sh +# From the root directory of diffusers. + +# Download the VQVAE checkpoint +$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D -O ithq_vqvae.pth + +# Download the VQVAE config +# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class +# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE` +# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml` +$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml + +# Download the main model checkpoint +$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D -O ithq_learnable.pth + +# Download the main model config +$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml + +# run the convert script +$ python ./scripts/convert_vq_diffusion_to_diffusers.py \ + --checkpoint_path ./ithq_learnable.pth \ + --original_config_file ./ithq.yaml \ + --vqvae_checkpoint_path ./ithq_vqvae.pth \ + --vqvae_original_config_file ./ithq_vqvae.yaml \ + --dump_path +``` +""" + +import argparse +import tempfile + +import torch +import yaml +from accelerate import init_empty_weights, load_checkpoint_and_dispatch +from transformers import CLIPTextModel, CLIPTokenizer +from yaml.loader import FullLoader + +from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel +from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings + + +# vqvae model + +PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"] + + +def vqvae_model_from_original_config(original_config): + assert ( + original_config["target"] in PORTED_VQVAES + ), f"{original_config['target']} has not yet been ported to diffusers." + + original_config = original_config["params"] + + original_encoder_config = original_config["encoder_config"]["params"] + original_decoder_config = original_config["decoder_config"]["params"] + + in_channels = original_encoder_config["in_channels"] + out_channels = original_decoder_config["out_ch"] + + down_block_types = get_down_block_types(original_encoder_config) + up_block_types = get_up_block_types(original_decoder_config) + + assert original_encoder_config["ch"] == original_decoder_config["ch"] + assert original_encoder_config["ch_mult"] == original_decoder_config["ch_mult"] + block_out_channels = tuple( + [original_encoder_config["ch"] * a_ch_mult for a_ch_mult in original_encoder_config["ch_mult"]] + ) + + assert original_encoder_config["num_res_blocks"] == original_decoder_config["num_res_blocks"] + layers_per_block = original_encoder_config["num_res_blocks"] + + assert original_encoder_config["z_channels"] == original_decoder_config["z_channels"] + latent_channels = original_encoder_config["z_channels"] + + num_vq_embeddings = original_config["n_embed"] + + # Hard coded value for ResnetBlock.GoupNorm(num_groups) in VQ-diffusion + norm_num_groups = 32 + + e_dim = original_config["embed_dim"] + + model = VQModel( + in_channels=in_channels, + out_channels=out_channels, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + latent_channels=latent_channels, + num_vq_embeddings=num_vq_embeddings, + norm_num_groups=norm_num_groups, + vq_embed_dim=e_dim, + ) + + return model + + +def get_down_block_types(original_encoder_config): + attn_resolutions = coerce_attn_resolutions(original_encoder_config["attn_resolutions"]) + num_resolutions = len(original_encoder_config["ch_mult"]) + resolution = coerce_resolution(original_encoder_config["resolution"]) + + curr_res = resolution + down_block_types = [] + + for _ in range(num_resolutions): + if curr_res in attn_resolutions: + down_block_type = "AttnDownEncoderBlock2D" + else: + down_block_type = "DownEncoderBlock2D" + + down_block_types.append(down_block_type) + + curr_res = [r // 2 for r in curr_res] + + return down_block_types + + +def get_up_block_types(original_decoder_config): + attn_resolutions = coerce_attn_resolutions(original_decoder_config["attn_resolutions"]) + num_resolutions = len(original_decoder_config["ch_mult"]) + resolution = coerce_resolution(original_decoder_config["resolution"]) + + curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution] + up_block_types = [] + + for _ in reversed(range(num_resolutions)): + if curr_res in attn_resolutions: + up_block_type = "AttnUpDecoderBlock2D" + else: + up_block_type = "UpDecoderBlock2D" + + up_block_types.append(up_block_type) + + curr_res = [r * 2 for r in curr_res] + + return up_block_types + + +def coerce_attn_resolutions(attn_resolutions): + attn_resolutions = list(attn_resolutions) + attn_resolutions_ = [] + for ar in attn_resolutions: + if isinstance(ar, (list, tuple)): + attn_resolutions_.append(list(ar)) + else: + attn_resolutions_.append([ar, ar]) + return attn_resolutions_ + + +def coerce_resolution(resolution): + if isinstance(resolution, int): + resolution = [resolution, resolution] # H, W + elif isinstance(resolution, (tuple, list)): + resolution = list(resolution) + else: + raise ValueError("Unknown type of resolution:", resolution) + return resolution + + +# done vqvae model + +# vqvae checkpoint + + +def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint)) + + # quant_conv + + diffusers_checkpoint.update( + { + "quant_conv.weight": checkpoint["quant_conv.weight"], + "quant_conv.bias": checkpoint["quant_conv.bias"], + } + ) + + # quantize + diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]}) + + # post_quant_conv + diffusers_checkpoint.update( + { + "post_quant_conv.weight": checkpoint["post_quant_conv.weight"], + "post_quant_conv.bias": checkpoint["post_quant_conv.bias"], + } + ) + + # decoder + diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint)) + + return diffusers_checkpoint + + +def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv_in + diffusers_checkpoint.update( + { + "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"], + "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"], + } + ) + + # down_blocks + for down_block_idx, down_block in enumerate(model.encoder.down_blocks): + diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}" + down_block_prefix = f"encoder.down.{down_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(down_block.resnets): + diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # downsample + + # do not include the downsample when on the last down block + # There is no downsample on the last down block + if down_block_idx != len(model.encoder.down_blocks) - 1: + # There's a single downsample in the original checkpoint but a list of downsamples + # in the diffusers model. + diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv" + downsample_prefix = f"{down_block_prefix}.downsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(down_block, "attentions"): + for attention_idx, _ in enumerate(down_block.attentions): + diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{down_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder + diffusers_attention_prefix = "encoder.mid_block.attentions.0" + attention_prefix = "encoder.mid.attn_1" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion encoder + resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"], + "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"], + # conv_out + "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"], + "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv in + diffusers_checkpoint.update( + { + "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"], + "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"], + } + ) + + # up_blocks + + for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks): + # up_blocks are stored in reverse order in the VQ-diffusion checkpoint + orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx + + diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}" + up_block_prefix = f"decoder.up.{orig_up_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(up_block.resnets): + diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # upsample + + # there is no up sample on the last up block + if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1: + # There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples + # in the diffusers model. + diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv" + downsample_prefix = f"{up_block_prefix}.upsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(up_block, "attentions"): + for attention_idx, _ in enumerate(up_block.attentions): + diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{up_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder + diffusers_attention_prefix = "decoder.mid_block.attentions.0" + attention_prefix = "decoder.mid.attn_1" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion decoder + resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"], + "decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"], + # conv_out + "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"], + "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + + +def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # group_norm + f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + # query + f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0, 0 + ], + f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + + +# done vqvae checkpoint + +# transformer model + +PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"] +PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"] +PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"] + + +def transformer_model_from_original_config( + original_diffusion_config, original_transformer_config, original_content_embedding_config +): + assert ( + original_diffusion_config["target"] in PORTED_DIFFUSIONS + ), f"{original_diffusion_config['target']} has not yet been ported to diffusers." + assert ( + original_transformer_config["target"] in PORTED_TRANSFORMERS + ), f"{original_transformer_config['target']} has not yet been ported to diffusers." + assert ( + original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS + ), f"{original_content_embedding_config['target']} has not yet been ported to diffusers." + + original_diffusion_config = original_diffusion_config["params"] + original_transformer_config = original_transformer_config["params"] + original_content_embedding_config = original_content_embedding_config["params"] + + inner_dim = original_transformer_config["n_embd"] + + n_heads = original_transformer_config["n_head"] + + # VQ-Diffusion gives dimension of the multi-headed attention layers as the + # number of attention heads times the sequence length (the dimension) of a + # single head. We want to specify our attention blocks with those values + # specified separately + assert inner_dim % n_heads == 0 + d_head = inner_dim // n_heads + + depth = original_transformer_config["n_layer"] + context_dim = original_transformer_config["condition_dim"] + + num_embed = original_content_embedding_config["num_embed"] + # the number of embeddings in the transformer includes the mask embedding. + # the content embedding (the vqvae) does not include the mask embedding. + num_embed = num_embed + 1 + + height = original_transformer_config["content_spatial_size"][0] + width = original_transformer_config["content_spatial_size"][1] + + assert width == height, "width has to be equal to height" + dropout = original_transformer_config["resid_pdrop"] + num_embeds_ada_norm = original_diffusion_config["diffusion_step"] + + model_kwargs = { + "attention_bias": True, + "cross_attention_dim": context_dim, + "attention_head_dim": d_head, + "num_layers": depth, + "dropout": dropout, + "num_attention_heads": n_heads, + "num_vector_embeds": num_embed, + "num_embeds_ada_norm": num_embeds_ada_norm, + "norm_num_groups": 32, + "sample_size": width, + "activation_fn": "geglu-approximate", + } + + model = Transformer2DModel(**model_kwargs) + return model + + +# done transformer model + +# transformer checkpoint + + +def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + transformer_prefix = "transformer.transformer" + + diffusers_latent_image_embedding_prefix = "latent_image_embedding" + latent_image_embedding_prefix = f"{transformer_prefix}.content_emb" + + # DalleMaskImageEmbedding + diffusers_checkpoint.update( + { + f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.emb.weight" + ], + f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.height_emb.weight" + ], + f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.width_emb.weight" + ], + } + ) + + # transformer blocks + for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks): + diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}" + transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}" + + # ada norm block + diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1" + ada_norm_prefix = f"{transformer_block_prefix}.ln1" + + diffusers_checkpoint.update( + transformer_ada_norm_to_diffusers_checkpoint( + checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix + ) + ) + + # attention block + diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1" + attention_prefix = f"{transformer_block_prefix}.attn1" + + diffusers_checkpoint.update( + transformer_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # ada norm block + diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2" + ada_norm_prefix = f"{transformer_block_prefix}.ln1_1" + + diffusers_checkpoint.update( + transformer_ada_norm_to_diffusers_checkpoint( + checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix + ) + ) + + # attention block + diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2" + attention_prefix = f"{transformer_block_prefix}.attn2" + + diffusers_checkpoint.update( + transformer_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # norm block + diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3" + norm_block_prefix = f"{transformer_block_prefix}.ln2" + + diffusers_checkpoint.update( + { + f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"], + f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"], + } + ) + + # feedforward block + diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff" + feedforward_prefix = f"{transformer_block_prefix}.mlp" + + diffusers_checkpoint.update( + transformer_feedforward_to_diffusers_checkpoint( + checkpoint, + diffusers_feedforward_prefix=diffusers_feedforward_prefix, + feedforward_prefix=feedforward_prefix, + ) + ) + + # to logits + + diffusers_norm_out_prefix = "norm_out" + norm_out_prefix = f"{transformer_prefix}.to_logits.0" + + diffusers_checkpoint.update( + { + f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"], + f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"], + } + ) + + diffusers_out_prefix = "out" + out_prefix = f"{transformer_prefix}.to_logits.1" + + diffusers_checkpoint.update( + { + f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"], + f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"], + } + ) + + return diffusers_checkpoint + + +def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix): + return { + f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"], + f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"], + f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"], + } + + +def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # key + f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"], + f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"], + # query + f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"], + f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"], + # value + f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"], + f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"], + # linear out + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"], + } + + +def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix): + return { + f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"], + f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"], + f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"], + f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"], + } + + +# done transformer checkpoint + + +def read_config_file(filename): + # The yaml file contains annotations that certain values should + # loaded as tuples. + with open(filename) as f: + original_config = yaml.load(f, FullLoader) + + return original_config + + +# We take separate arguments for the vqvae because the ITHQ vqvae config file +# is separate from the config file for the rest of the model. +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--vqvae_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the vqvae checkpoint to convert.", + ) + + parser.add_argument( + "--vqvae_original_config_file", + default=None, + type=str, + required=True, + help="The YAML config file corresponding to the original architecture for the vqvae.", + ) + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + + parser.add_argument( + "--original_config_file", + default=None, + type=str, + required=True, + help="The YAML config file corresponding to the original architecture.", + ) + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + # See link for how ema weights are always selected + # https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65 + parser.add_argument( + "--no_use_ema", + action="store_true", + required=False, + help=( + "Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to set" + " it as the original VQ-Diffusion always uses the ema weights when loading models." + ), + ) + + args = parser.parse_args() + + use_ema = not args.no_use_ema + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + # vqvae_model + + print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}") + + vqvae_original_config = read_config_file(args.vqvae_original_config_file).model + vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"] + + with init_empty_weights(): + vqvae_model = vqvae_model_from_original_config(vqvae_original_config) + + vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint) + + with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file: + torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name) + del vqvae_diffusers_checkpoint + del vqvae_checkpoint + load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto") + + print("done loading vqvae") + + # done vqvae_model + + # transformer_model + + print( + f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:" + f" {use_ema}" + ) + + original_config = read_config_file(args.original_config_file).model + + diffusion_config = original_config["params"]["diffusion_config"] + transformer_config = original_config["params"]["diffusion_config"]["params"]["transformer_config"] + content_embedding_config = original_config["params"]["diffusion_config"]["params"]["content_emb_config"] + + pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location) + + if use_ema: + if "ema" in pre_checkpoint: + checkpoint = {} + for k, v in pre_checkpoint["model"].items(): + checkpoint[k] = v + + for k, v in pre_checkpoint["ema"].items(): + # The ema weights are only used on the transformer. To mimic their key as if they came + # from the state_dict for the top level model, we prefix with an additional "transformer." + # See the source linked in the args.use_ema config for more information. + checkpoint[f"transformer.{k}"] = v + else: + print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.") + checkpoint = pre_checkpoint["model"] + else: + checkpoint = pre_checkpoint["model"] + + del pre_checkpoint + + with init_empty_weights(): + transformer_model = transformer_model_from_original_config( + diffusion_config, transformer_config, content_embedding_config + ) + + diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint( + transformer_model, checkpoint + ) + + # classifier free sampling embeddings interlude + + # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate + # model, so we pull them off the checkpoint before the checkpoint is deleted. + + learnable_classifier_free_sampling_embeddings = diffusion_config["params"].learnable_cf + + if learnable_classifier_free_sampling_embeddings: + learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"] + else: + learned_classifier_free_sampling_embeddings_embeddings = None + + # done classifier free sampling embeddings interlude + + with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file: + torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name) + del diffusers_transformer_checkpoint + del checkpoint + load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto") + + print("done loading transformer") + + # done transformer_model + + # text encoder + + print("loading CLIP text encoder") + + clip_name = "openai/clip-vit-base-patch32" + + # The original VQ-Diffusion specifies the pad value by the int used in the + # returned tokens. Each model uses `0` as the pad value. The transformers clip api + # specifies the pad value via the token before it has been tokenized. The `!` pad + # token is the same as padding with the `0` pad value. + pad_token = "!" + + tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto") + + assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0 + + text_encoder_model = CLIPTextModel.from_pretrained( + clip_name, + # `CLIPTextModel` does not support device_map="auto" + # device_map="auto" + ) + + print("done loading CLIP text encoder") + + # done text encoder + + # scheduler + + scheduler_model = VQDiffusionScheduler( + # the scheduler has the same number of embeddings as the transformer + num_vec_classes=transformer_model.num_vector_embeds + ) + + # done scheduler + + # learned classifier free sampling embeddings + + with init_empty_weights(): + learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings( + learnable_classifier_free_sampling_embeddings, + hidden_size=text_encoder_model.config.hidden_size, + length=tokenizer_model.model_max_length, + ) + + learned_classifier_free_sampling_checkpoint = { + "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float() + } + + with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file: + torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name) + del learned_classifier_free_sampling_checkpoint + del learned_classifier_free_sampling_embeddings_embeddings + load_checkpoint_and_dispatch( + learned_classifier_free_sampling_embeddings_model, + learned_classifier_free_sampling_checkpoint_file.name, + device_map="auto", + ) + + # done learned classifier free sampling embeddings + + print(f"saving VQ diffusion model, path: {args.dump_path}") + + pipe = VQDiffusionPipeline( + vqvae=vqvae_model, + transformer=transformer_model, + tokenizer=tokenizer_model, + text_encoder=text_encoder_model, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model, + scheduler=scheduler_model, + ) + pipe.save_pretrained(args.dump_path) + + print("done writing VQ diffusion model") diff --git a/diffusers/scripts/convert_wuerstchen.py b/diffusers/scripts/convert_wuerstchen.py new file mode 100755 index 0000000..23d45d3 --- /dev/null +++ b/diffusers/scripts/convert_wuerstchen.py @@ -0,0 +1,115 @@ +# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/ +import os + +import torch +from transformers import AutoTokenizer, CLIPTextModel +from vqgan import VQModel + +from diffusers import ( + DDPMWuerstchenScheduler, + WuerstchenCombinedPipeline, + WuerstchenDecoderPipeline, + WuerstchenPriorPipeline, +) +from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior + + +model_path = "models/" +device = "cpu" + +paella_vqmodel = VQModel() +state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"] +paella_vqmodel.load_state_dict(state_dict) + +state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"] +state_dict.pop("vquantizer.codebook.weight") +vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent) +vqmodel.load_state_dict(state_dict) + +# Clip Text encoder and tokenizer +text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") +tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + +# Generator +gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") +gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + +orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"] +state_dict = {} +for key in orig_state_dict.keys(): + if key.endswith("in_proj_weight"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + else: + state_dict[key] = orig_state_dict[key] +deocder = WuerstchenDiffNeXt() +deocder.load_state_dict(state_dict) + +# Prior +orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"] +state_dict = {} +for key in orig_state_dict.keys(): + if key.endswith("in_proj_weight"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + else: + state_dict[key] = orig_state_dict[key] +prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) +prior_model.load_state_dict(state_dict) + +# scheduler +scheduler = DDPMWuerstchenScheduler() + +# Prior pipeline +prior_pipeline = WuerstchenPriorPipeline( + prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler +) + +prior_pipeline.save_pretrained("warp-ai/wuerstchen-prior") + +decoder_pipeline = WuerstchenDecoderPipeline( + text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler +) +decoder_pipeline.save_pretrained("warp-ai/wuerstchen") + +# Wuerstchen pipeline +wuerstchen_pipeline = WuerstchenCombinedPipeline( + # Decoder + text_encoder=gen_text_encoder, + tokenizer=gen_tokenizer, + decoder=deocder, + scheduler=scheduler, + vqgan=vqmodel, + # Prior + prior_tokenizer=tokenizer, + prior_text_encoder=text_encoder, + prior=prior_model, + prior_scheduler=scheduler, +) +wuerstchen_pipeline.save_pretrained("warp-ai/WuerstchenCombinedPipeline") diff --git a/diffusers/scripts/convert_zero123_to_diffusers.py b/diffusers/scripts/convert_zero123_to_diffusers.py new file mode 100755 index 0000000..b46633f --- /dev/null +++ b/diffusers/scripts/convert_zero123_to_diffusers.py @@ -0,0 +1,807 @@ +""" +This script modified from +https://github.com/huggingface/diffusers/blob/bc691231360a4cbc7d19a58742ebb8ed0f05e027/scripts/convert_original_stable_diffusion_to_diffusers.py + +Convert original Zero1to3 checkpoint to diffusers checkpoint. + +# run the convert script +$ python convert_zero123_to_diffusers.py \ + --checkpoint_path /path/zero123/105000.ckpt \ + --dump_path ./zero1to3 \ + --original_config_file /path/zero123/configs/sd-objaverse-finetune-c_concat-256.yaml +``` +""" + +import argparse + +import torch +import yaml +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from pipeline_zero1to3 import CCProjection, Zero1to3StableDiffusionPipeline +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, +) + +from diffusers.models import ( + AutoencoderKL, + UNet2DConditionModel, +) +from diffusers.schedulers import DDIMScheduler +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if controlnet: + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] + else: + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] + else: + unet_params = original_config["model"]["params"]["network_config"]["params"] + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params["transformer_depth"] is not None: + transformer_layers_per_block = ( + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params["context_dim"] is not None: + context_dim = ( + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] + ) + + if "num_classes" in unet_params: + if unet_params["num_classes"] == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + else: + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params['num_classes']}") + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if controlnet: + config["conditioning_channels"] = unet_params["hint_channels"] + else: + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] = tuple(up_block_types) + + return config + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint[flat_ema_key] + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint[key] + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + } + return config + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, extract_ema, device): + ckpt = torch.load(checkpoint_path, map_location=device) + ckpt["global_step"] + checkpoint = ckpt["state_dict"] + del ckpt + torch.cuda.empty_cache() + + original_config = yaml.safe_load(original_config_file) + original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] + num_in_channels = 8 + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + prediction_type = "epsilon" + image_size = 256 + num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000 + + beta_start = getattr(original_config["model"]["params"], "linear_start", None) or 0.02 + beta_end = getattr(original_config["model"]["params"], "linear_end", None) or 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + scheduler.register_to_config(clip_sample=False) + + # Convert the UNet2DConditionModel model. + upcast_attention = None + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = upcast_attention + with init_empty_weights(): + unet = UNet2DConditionModel(**unet_config) + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=None, extract_ema=extract_ema + ) + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + + # Convert the VAE model. + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config["model"] + and "scale_factor" in original_config["model"]["params"] + ): + vae_scaling_factor = original_config["model"]["params"]["scale_factor"] + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + with init_empty_weights(): + vae = AutoencoderKL(**vae_config) + + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + + feature_extractor = CLIPImageProcessor.from_pretrained( + "lambdalabs/sd-image-variations-diffusers", subfolder="feature_extractor" + ) + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "lambdalabs/sd-image-variations-diffusers", subfolder="image_encoder" + ) + + cc_projection = CCProjection() + cc_projection.load_state_dict( + { + "projection.weight": checkpoint["cc_projection.weight"].cpu(), + "projection.bias": checkpoint["cc_projection.bias"].cpu(), + } + ) + + pipe = Zero1to3StableDiffusionPipeline( + vae, image_encoder, unet, scheduler, None, feature_extractor, cc_projection, requires_safety_checker=False + ) + + return pipe + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--original_config_file", + default=None, + type=str, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--half", action="store_true", help="Save weights in half precision.") + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + args = parser.parse_args() + + pipe = convert_from_original_zero123_ckpt( + checkpoint_path=args.checkpoint_path, + original_config_file=args.original_config_file, + extract_ema=args.extract_ema, + device=args.device, + ) + + if args.half: + pipe.to(dtype=torch.float16) + + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/diffusers/scripts/generate_logits.py b/diffusers/scripts/generate_logits.py new file mode 100755 index 0000000..99d46d6 --- /dev/null +++ b/diffusers/scripts/generate_logits.py @@ -0,0 +1,127 @@ +import random + +import torch +from huggingface_hub import HfApi + +from diffusers import UNet2DModel + + +api = HfApi() + +results = {} +# fmt: off +results["google_ddpm_cifar10_32"] = torch.tensor([ + -0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467, + 1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189, + -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839, + 0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557 +]) +results["google_ddpm_ema_bedroom_256"] = torch.tensor([ + -2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436, + 1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208, + -2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948, + 2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365 +]) +results["CompVis_ldm_celebahq_256"] = torch.tensor([ + -0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869, + -0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304, + -0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925, + 0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943 +]) +results["google_ncsnpp_ffhq_1024"] = torch.tensor([ + 0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172, + -0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309, + 0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805, + -0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505 +]) +results["google_ncsnpp_bedroom_256"] = torch.tensor([ + 0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133, + -0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395, + 0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559, + -0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386 +]) +results["google_ncsnpp_celebahq_256"] = torch.tensor([ + 0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078, + -0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330, + 0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683, + -0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431 +]) +results["google_ncsnpp_church_256"] = torch.tensor([ + 0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042, + -0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398, + 0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574, + -0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390 +]) +results["google_ncsnpp_ffhq_256"] = torch.tensor([ + 0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042, + -0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290, + 0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746, + -0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473 +]) +results["google_ddpm_cat_256"] = torch.tensor([ + -1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330, + 1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243, + -2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810, + 1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251]) +results["google_ddpm_celebahq_256"] = torch.tensor([ + -1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324, + 0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181, + -2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259, + 1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266 +]) +results["google_ddpm_ema_celebahq_256"] = torch.tensor([ + -1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212, + 0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027, + -2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131, + 1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355 +]) +results["google_ddpm_church_256"] = torch.tensor([ + -2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959, + 1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351, + -3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341, + 3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066 +]) +results["google_ddpm_bedroom_256"] = torch.tensor([ + -2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740, + 1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398, + -2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395, + 2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243 +]) +results["google_ddpm_ema_church_256"] = torch.tensor([ + -2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336, + 1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908, + -3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560, + 3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343 +]) +results["google_ddpm_ema_cat_256"] = torch.tensor([ + -1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344, + 1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391, + -2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439, + 1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219 +]) +# fmt: on + +models = api.list_models(filter="diffusers") +for mod in models: + if "google" in mod.author or mod.id == "CompVis/ldm-celebahq-256": + local_checkpoint = "/home/patrick/google_checkpoints/" + mod.id.split("/")[-1] + + print(f"Started running {mod.id}!!!") + + if mod.id.startswith("CompVis"): + model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet") + else: + model = UNet2DModel.from_pretrained(local_checkpoint) + + torch.manual_seed(0) + random.seed(0) + + noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + time_step = torch.tensor([10] * noise.shape[0]) + with torch.no_grad(): + logits = model(noise, time_step).sample + + assert torch.allclose( + logits[0, 0, 0, :30], results["_".join("_".join(mod.id.split("/")).split("-"))], atol=1e-3 + ) + print(f"{mod.id} has passed successfully!!!") diff --git a/diffusers/setup.py b/diffusers/setup.py new file mode 100755 index 0000000..89e18ca --- /dev/null +++ b/diffusers/setup.py @@ -0,0 +1,286 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/main/setup.py + +To create the package for PyPI. + +1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the + documentation. + + If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make + for the post-release and run `make fix-copies` on the main branch as well. + +2. Unpin specific versions from setup.py that use a git install. + +3. Checkout the release branch (v-release, for example v4.19-release), and commit these changes with the + message: "Release: " and push. + +4. Manually trigger the "Nightly and release tests on main/release branch" workflow from the release branch. Wait for + the tests to complete. We can safely ignore the known test failures. + +5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs). + +6. Add a tag in git to mark the release: "git tag v -m 'Adds tag v for PyPI'" + Push the tag to git: git push --tags origin v-release + +7. Build both the sources and the wheel. Do not change anything in setup.py between + creating the wheel and the source distribution (obviously). + + For the wheel, run: "python setup.py bdist_wheel" in the top level directory + (This will build a wheel for the Python version you use to build it). + + For the sources, run: "python setup.py sdist" + You should now have a /dist directory with both .whl and .tar.gz source versions. + + Long story cut short, you need to run both before you can upload the distribution to the + test PyPI and the actual PyPI servers: + + python setup.py bdist_wheel && python setup.py sdist + +8. Check that everything looks correct by uploading the package to the PyPI test server: + + twine upload dist/* -r pypitest + (pypi suggests using twine as other methods upload files via plaintext.) + You may have to specify the repository url, use the following command then: + twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ + + Check that you can install it in a virtualenv by running: + pip install -i https://testpypi.python.org/pypi diffusers + + If you are testing from a Colab Notebook, for instance, then do: + pip install diffusers && pip uninstall diffusers + pip install -i https://testpypi.python.org/pypi diffusers + + Check you can run the following commands: + python -c "from diffusers import __version__; print(__version__)" + python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()" + python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')" + python -c "from diffusers import *" + +9. Upload the final version to the actual PyPI: + twine upload dist/* -r pypi + +10. Prepare the release notes and publish them on GitHub once everything is looking hunky-dory. You can use the following + Space to fetch all the commits applicable for the release: https://huggingface.co/spaces/lysandre/github-release. Repo should + be `huggingface/diffusers`. `tag` should be the previous release tag (v0.26.1, for example), and `branch` should be + the latest release branch (v0.27.0-release, for example). It denotes all commits that have happened on branch + v0.27.0-release after the tag v0.26.1 was created. + +11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release, + you need to go back to main before executing this. +""" + +import os +import re +import sys + +from setuptools import Command, find_packages, setup + + +# IMPORTANT: +# 1. all dependencies should be listed here with their version requirements if any +# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py +_deps = [ + "Pillow", # keep the PIL.Image.Resampling deprecation away + "accelerate>=0.31.0", + "compel==0.1.8", + "datasets", + "filelock", + "flax>=0.4.1", + "hf-doc-builder>=0.3.0", + "huggingface-hub>=0.23.2", + "requests-mock==1.10.0", + "importlib_metadata", + "invisible-watermark>=0.2.0", + "isort>=5.5.4", + "jax>=0.4.1", + "jaxlib>=0.4.1", + "Jinja2", + "k-diffusion>=0.0.12", + "torchsde", + "note_seq", + "librosa", + "numpy", + "parameterized", + "peft>=0.6.0", + "protobuf>=3.20.3,<4", + "pytest", + "pytest-timeout", + "pytest-xdist", + "python>=3.8.0", + "ruff==0.1.5", + "safetensors>=0.3.1", + "sentencepiece>=0.1.91,!=0.1.92", + "GitPython<3.1.19", + "scipy", + "onnx", + "regex!=2019.12.17", + "requests", + "tensorboard", + "torch>=1.4", + "torchvision", + "transformers>=4.41.2", + "urllib3<=2.0.0", + "black", +] + +# this is a lookup table with items like: +# +# tokenizers: "huggingface-hub==0.8.0" +# packaging: "packaging" +# +# some of the values are versioned whereas others aren't. +deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)} + +# since we save this data in src/diffusers/dependency_versions_table.py it can be easily accessed from +# anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with: +# +# python -c 'import sys; from diffusers.dependency_versions_table import deps; \ +# print(" ".join([deps[x] for x in sys.argv[1:]]))' tokenizers datasets +# +# Just pass the desired package names to that script as it's shown with 2 packages above. +# +# If diffusers is not yet installed and the work is done from the cloned repo remember to add `PYTHONPATH=src` to the script above +# +# You can then feed this for example to `pip`: +# +# pip install -U $(python -c 'import sys; from diffusers.dependency_versions_table import deps; \ +# print(" ".join([deps[x] for x in sys.argv[1:]]))' tokenizers datasets) +# + + +def deps_list(*pkgs): + return [deps[pkg] for pkg in pkgs] + + +class DepsTableUpdateCommand(Command): + """ + A custom command that updates the dependency table. + usage: python setup.py deps_table_update + """ + + description = "build runtime dependency table" + user_options = [ + # format: (long option, short option, description). + ( + "dep-table-update", + None, + "updates src/diffusers/dependency_versions_table.py", + ), + ] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + entries = "\n".join([f' "{k}": "{v}",' for k, v in deps.items()]) + content = [ + "# THIS FILE HAS BEEN AUTOGENERATED. To update:", + "# 1. modify the `_deps` dict in setup.py", + "# 2. run `make deps_table_update`", + "deps = {", + entries, + "}", + "", + ] + target = "src/diffusers/dependency_versions_table.py" + print(f"updating {target}") + with open(target, "w", encoding="utf-8", newline="\n") as f: + f.write("\n".join(content)) + + +extras = {} +extras["quality"] = deps_list("urllib3", "isort", "ruff", "hf-doc-builder") +extras["docs"] = deps_list("hf-doc-builder") +extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft") +extras["test"] = deps_list( + "compel", + "GitPython", + "datasets", + "Jinja2", + "invisible-watermark", + "k-diffusion", + "librosa", + "parameterized", + "pytest", + "pytest-timeout", + "pytest-xdist", + "requests-mock", + "safetensors", + "sentencepiece", + "scipy", + "torchvision", + "transformers", +) +extras["torch"] = deps_list("torch", "accelerate") + +if os.name == "nt": # windows + extras["flax"] = [] # jax is not supported on windows +else: + extras["flax"] = deps_list("jax", "jaxlib", "flax") + +extras["dev"] = ( + extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"] +) + +install_requires = [ + deps["importlib_metadata"], + deps["filelock"], + deps["huggingface-hub"], + deps["numpy"], + deps["regex"], + deps["requests"], + deps["safetensors"], + deps["Pillow"], +] + +version_range_max = max(sys.version_info[1], 10) + 1 + +setup( + name="diffusers", + version="0.31.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + description="State-of-the-art diffusion in PyTorch and JAX.", + long_description=open("README.md", "r", encoding="utf-8").read(), + long_description_content_type="text/markdown", + keywords="deep learning diffusion jax pytorch stable diffusion audioldm", + license="Apache 2.0 License", + author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/diffusers/graphs/contributors)", + author_email="diffusers@huggingface.co", + url="https://github.com/huggingface/diffusers", + package_dir={"": "src"}, + packages=find_packages("src"), + package_data={"diffusers": ["py.typed"]}, + include_package_data=True, + python_requires=">=3.8.0", + install_requires=list(install_requires), + extras_require=extras, + entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]}, + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Programming Language :: Python :: 3", + ] + + [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)], + cmdclass={"deps_table_update": DepsTableUpdateCommand}, +) diff --git a/diffusers/src/diffusers/__init__.py b/diffusers/src/diffusers/__init__.py new file mode 100755 index 0000000..41b245c --- /dev/null +++ b/diffusers/src/diffusers/__init__.py @@ -0,0 +1,964 @@ +__version__ = "0.31.0.dev0" + +from typing import TYPE_CHECKING + +from .utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_k_diffusion_available, + is_librosa_available, + is_note_seq_available, + is_onnx_available, + is_scipy_available, + is_sentencepiece_available, + is_torch_available, + is_torchsde_available, + is_transformers_available, +) + + +# Lazy Import based on +# https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py + +# When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names, +# and is used to defer the actual importing for when the objects are requested. +# This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends). + +_import_structure = { + "configuration_utils": ["ConfigMixin"], + "loaders": ["FromOriginalModelMixin"], + "models": [], + "pipelines": [], + "schedulers": [], + "utils": [ + "OptionalDependencyNotAvailable", + "is_flax_available", + "is_inflect_available", + "is_invisible_watermark_available", + "is_k_diffusion_available", + "is_k_diffusion_version", + "is_librosa_available", + "is_note_seq_available", + "is_onnx_available", + "is_scipy_available", + "is_torch_available", + "is_torchsde_available", + "is_transformers_available", + "is_transformers_version", + "is_unidecode_available", + "logging", + ], +} + +try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_onnx_objects # noqa F403 + + _import_structure["utils.dummy_onnx_objects"] = [ + name for name in dir(dummy_onnx_objects) if not name.startswith("_") + ] + +else: + _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_pt_objects # noqa F403 + + _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] + +else: + _import_structure["models"].extend( + [ + "AsymmetricAutoencoderKL", + "AuraFlowTransformer2DModel", + "AutoencoderKL", + "AutoencoderKLCogVideoX", + "AutoencoderKLTemporalDecoder", + "AutoencoderOobleck", + "AutoencoderTiny", + "CogVideoXTransformer3DModel", + "CogVideoXTransformer3DInpaintModel", + "CogvideoXBranchModel", + "ConsistencyDecoderVAE", + "ControlNetModel", + "ControlNetXSAdapter", + "DiTTransformer2DModel", + "FluxControlNetModel", + "FluxMultiControlNetModel", + "FluxTransformer2DModel", + "HunyuanDiT2DControlNetModel", + "HunyuanDiT2DModel", + "HunyuanDiT2DMultiControlNetModel", + "I2VGenXLUNet", + "Kandinsky3UNet", + "LatteTransformer3DModel", + "LuminaNextDiT2DModel", + "ModelMixin", + "MotionAdapter", + "MultiAdapter", + "PixArtTransformer2DModel", + "PriorTransformer", + "SD3ControlNetModel", + "SD3MultiControlNetModel", + "SD3Transformer2DModel", + "SparseControlNetModel", + "StableAudioDiTModel", + "StableCascadeUNet", + "T2IAdapter", + "FullAdapter", + "T5FilmDecoder", + "Transformer2DModel", + "UNet1DModel", + "UNet2DConditionModel", + "UNet2DModel", + "UNet3DConditionModel", + "UNetControlNetXSModel", + "UNetMotionModel", + "UNetSpatioTemporalConditionModel", + "UVit2DModel", + "VQModel", + ] + ) + + _import_structure["optimization"] = [ + "get_constant_schedule", + "get_constant_schedule_with_warmup", + "get_cosine_schedule_with_warmup", + "get_cosine_with_hard_restarts_schedule_with_warmup", + "get_linear_schedule_with_warmup", + "get_polynomial_decay_schedule_with_warmup", + "get_scheduler", + ] + _import_structure["pipelines"].extend( + [ + "AudioPipelineOutput", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + "ConsistencyModelPipeline", + "DanceDiffusionPipeline", + "DDIMPipeline", + "DDPMPipeline", + "DiffusionPipeline", + "DiTPipeline", + "ImagePipelineOutput", + "KarrasVePipeline", + "LDMPipeline", + "LDMSuperResolutionPipeline", + "PNDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + "StableDiffusionMixin", + ] + ) + _import_structure["schedulers"].extend( + [ + "AmusedScheduler", + "CMStochasticIterativeScheduler", + "CogVideoXDDIMScheduler", + "CogVideoXDPMScheduler", + "DDIMInverseScheduler", + "DDIMParallelScheduler", + "DDIMScheduler", + "DDPMParallelScheduler", + "DDPMScheduler", + "DDPMWuerstchenScheduler", + "DEISMultistepScheduler", + "DPMSolverMultistepInverseScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "EDMDPMSolverMultistepScheduler", + "EDMEulerScheduler", + "EulerAncestralDiscreteScheduler", + "EulerDiscreteScheduler", + "FlowMatchEulerDiscreteScheduler", + "FlowMatchHeunDiscreteScheduler", + "HeunDiscreteScheduler", + "IPNDMScheduler", + "KarrasVeScheduler", + "KDPM2AncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "LCMScheduler", + "PNDMScheduler", + "RePaintScheduler", + "SASolverScheduler", + "SchedulerMixin", + "ScoreSdeVeScheduler", + "TCDScheduler", + "UnCLIPScheduler", + "UniPCMultistepScheduler", + "VQDiffusionScheduler", + ] + ) + _import_structure["training_utils"] = ["EMAModel"] + +try: + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_scipy_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_scipy_objects"] = [ + name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") + ] + +else: + _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) + +try: + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_torchsde_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ + name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") + ] + +else: + _import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"]) + +try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_transformers_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_transformers_objects"] = [ + name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + ] + +else: + _import_structure["pipelines"].extend( + [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AmusedImg2ImgPipeline", + "AmusedInpaintPipeline", + "AmusedPipeline", + "AnimateDiffControlNetPipeline", + "AnimateDiffPAGPipeline", + "AnimateDiffPipeline", + "AnimateDiffSDXLPipeline", + "AnimateDiffSparseControlNetPipeline", + "AnimateDiffVideoToVideoControlNetPipeline", + "AnimateDiffVideoToVideoPipeline", + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + "AudioLDMPipeline", + "AuraFlowPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", + "CLIPImageProjection", + "CogVideoXImageToVideoPipeline", + "CogVideoXPipeline", + "CogVideoXVideoToVideoPipeline", + "CogVideoXInpaintPipeline", + "CogVideoXDualInpaintPipeline", + "CogVideoXSFTInpaintPipeline", + "CogVideoXSelfGuidanceInpaintPipeline", + "CogVideoXImageToVideoInpaintPipeline", + "CogVideoXI2VDualInpaintPipeline", + "CogVideoXI2VDualInpaintAnyLPipeline", + "CogVideoXI2VInpaintAnyLPipeline", + "CycleDiffusionPipeline", + "FluxControlNetImg2ImgPipeline", + "FluxControlNetInpaintPipeline", + "FluxControlNetPipeline", + "FluxImg2ImgPipeline", + "FluxInpaintPipeline", + "FluxPipeline", + "FluxFillPipeline", + "HunyuanDiTControlNetPipeline", + "HunyuanDiTPAGPipeline", + "HunyuanDiTPipeline", + "I2VGenXLPipeline", + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + "ImageTextPipelineOutput", + "Kandinsky3Img2ImgPipeline", + "Kandinsky3Pipeline", + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + "LatentConsistencyModelImg2ImgPipeline", + "LatentConsistencyModelPipeline", + "LattePipeline", + "LDMTextToImagePipeline", + "LEditsPPPipelineStableDiffusion", + "LEditsPPPipelineStableDiffusionXL", + "LuminaText2ImgPipeline", + "MarigoldDepthPipeline", + "MarigoldNormalsPipeline", + "MusicLDMPipeline", + "PaintByExamplePipeline", + "PIAPipeline", + "PixArtAlphaPipeline", + "PixArtSigmaPAGPipeline", + "PixArtSigmaPipeline", + "SemanticStableDiffusionPipeline", + "ShapEImg2ImgPipeline", + "ShapEPipeline", + "StableAudioPipeline", + "StableAudioProjectionModel", + "StableCascadeCombinedPipeline", + "StableCascadeDecoderPipeline", + "StableCascadePriorPipeline", + "StableDiffusion3ControlNetInpaintingPipeline", + "StableDiffusion3ControlNetPipeline", + "StableDiffusion3Img2ImgPipeline", + "StableDiffusion3InpaintPipeline", + "StableDiffusion3PAGPipeline", + "StableDiffusion3Pipeline", + "StableDiffusionAdapterPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPAGPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionControlNetXSPipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPAGPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPipelineSafe", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableDiffusionXLAdapterPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPAGImg2ImgPipeline", + "StableDiffusionXLControlNetPAGPipeline", + "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLControlNetXSPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPAGImg2ImgPipeline", + "StableDiffusionXLPAGInpaintPipeline", + "StableDiffusionXLPAGPipeline", + "StableDiffusionXLPipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "StableVideoDiffusionPipeline", + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "TextToVideoZeroSDXLPipeline", + "UnCLIPImageVariationPipeline", + "UnCLIPPipeline", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + "VideoToVideoSDPipeline", + "VQDiffusionPipeline", + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] + ) + +try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") + ] + +else: + _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"]) + +try: + if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_") + ] + +else: + _import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"]) + +try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ + name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") + ] + +else: + _import_structure["pipelines"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) + +try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_librosa_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_librosa_objects"] = [ + name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") + ] + +else: + _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) + +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + + _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ + name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") + ] + + +else: + _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) + +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_flax_objects # noqa F403 + + _import_structure["utils.dummy_flax_objects"] = [ + name for name in dir(dummy_flax_objects) if not name.startswith("_") + ] + + +else: + _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] + _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] + _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) + _import_structure["schedulers"].extend( + [ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxDPMSolverMultistepScheduler", + "FlaxEulerDiscreteScheduler", + "FlaxKarrasVeScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxPNDMScheduler", + "FlaxSchedulerMixin", + "FlaxScoreSdeVeScheduler", + ] + ) + + +try: + if not (is_flax_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_flax_and_transformers_objects # noqa F403 + + _import_structure["utils.dummy_flax_and_transformers_objects"] = [ + name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") + ] + + +else: + _import_structure["pipelines"].extend( + [ + "FlaxStableDiffusionControlNetPipeline", + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + "FlaxStableDiffusionXLPipeline", + ] + ) + +try: + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_note_seq_objects # noqa F403 + + _import_structure["utils.dummy_note_seq_objects"] = [ + name for name in dir(dummy_note_seq_objects) if not name.startswith("_") + ] + + +else: + _import_structure["pipelines"].extend(["MidiProcessor"]) + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .configuration_utils import ConfigMixin + + try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_onnx_objects import * # noqa F403 + else: + from .pipelines import OnnxRuntimeModel + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_pt_objects import * # noqa F403 + else: + from .models import ( + AsymmetricAutoencoderKL, + AuraFlowTransformer2DModel, + AutoencoderKL, + AutoencoderKLCogVideoX, + AutoencoderKLTemporalDecoder, + AutoencoderOobleck, + AutoencoderTiny, + CogVideoXTransformer3DModel, + CogVideoXTransformer3DInpaintModel, + CogvideoXBranchModel, + ConsistencyDecoderVAE, + ControlNetModel, + ControlNetXSAdapter, + DiTTransformer2DModel, + FluxControlNetModel, + FluxMultiControlNetModel, + FluxTransformer2DModel, + HunyuanDiT2DControlNetModel, + HunyuanDiT2DModel, + HunyuanDiT2DMultiControlNetModel, + I2VGenXLUNet, + Kandinsky3UNet, + LatteTransformer3DModel, + LuminaNextDiT2DModel, + ModelMixin, + MotionAdapter, + MultiAdapter, + PixArtTransformer2DModel, + PriorTransformer, + SD3ControlNetModel, + SD3MultiControlNetModel, + SD3Transformer2DModel, + SparseControlNetModel, + StableAudioDiTModel, + T2IAdapter, + FullAdapter, + T5FilmDecoder, + Transformer2DModel, + UNet1DModel, + UNet2DConditionModel, + UNet2DModel, + UNet3DConditionModel, + UNetControlNetXSModel, + UNetMotionModel, + UNetSpatioTemporalConditionModel, + UVit2DModel, + VQModel, + ) + from .optimization import ( + get_constant_schedule, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, + get_scheduler, + ) + from .pipelines import ( + AudioPipelineOutput, + AutoPipelineForImage2Image, + AutoPipelineForInpainting, + AutoPipelineForText2Image, + BlipDiffusionControlNetPipeline, + BlipDiffusionPipeline, + CLIPImageProjection, + ConsistencyModelPipeline, + DanceDiffusionPipeline, + DDIMPipeline, + DDPMPipeline, + DiffusionPipeline, + DiTPipeline, + ImagePipelineOutput, + KarrasVePipeline, + LDMPipeline, + LDMSuperResolutionPipeline, + PNDMPipeline, + RePaintPipeline, + ScoreSdeVePipeline, + StableDiffusionMixin, + ) + from .schedulers import ( + AmusedScheduler, + CMStochasticIterativeScheduler, + CogVideoXDDIMScheduler, + CogVideoXDPMScheduler, + DDIMInverseScheduler, + DDIMParallelScheduler, + DDIMScheduler, + DDPMParallelScheduler, + DDPMScheduler, + DDPMWuerstchenScheduler, + DEISMultistepScheduler, + DPMSolverMultistepInverseScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + EDMDPMSolverMultistepScheduler, + EDMEulerScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + FlowMatchEulerDiscreteScheduler, + FlowMatchHeunDiscreteScheduler, + HeunDiscreteScheduler, + IPNDMScheduler, + KarrasVeScheduler, + KDPM2AncestralDiscreteScheduler, + KDPM2DiscreteScheduler, + LCMScheduler, + PNDMScheduler, + RePaintScheduler, + SASolverScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, + TCDScheduler, + UnCLIPScheduler, + UniPCMultistepScheduler, + VQDiffusionScheduler, + ) + from .training_utils import EMAModel + + try: + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_scipy_objects import * # noqa F403 + else: + from .schedulers import LMSDiscreteScheduler + + try: + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_torchsde_objects import * # noqa F403 + else: + from .schedulers import CosineDPMSolverMultistepScheduler, DPMSolverSDEScheduler + + try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipelines import ( + AltDiffusionImg2ImgPipeline, + AltDiffusionPipeline, + AmusedImg2ImgPipeline, + AmusedInpaintPipeline, + AmusedPipeline, + AnimateDiffControlNetPipeline, + AnimateDiffPAGPipeline, + AnimateDiffPipeline, + AnimateDiffSDXLPipeline, + AnimateDiffSparseControlNetPipeline, + AnimateDiffVideoToVideoControlNetPipeline, + AnimateDiffVideoToVideoPipeline, + AudioLDM2Pipeline, + AudioLDM2ProjectionModel, + AudioLDM2UNet2DConditionModel, + AudioLDMPipeline, + AuraFlowPipeline, + CLIPImageProjection, + CogVideoXImageToVideoPipeline, + CogVideoXPipeline, + CogVideoXVideoToVideoPipeline, + CogVideoXInpaintPipeline, + CogVideoXDualInpaintPipeline, + CogVideoXSFTInpaintPipeline, + CogVideoXSelfGuidanceInpaintPipeline, + CogVideoXImageToVideoInpaintPipeline, + CogVideoXI2VDualInpaintPipeline, + CogVideoXI2VDualInpaintAnyLPipeline, + CogVideoXI2VInpaintAnyLPipeline, + CycleDiffusionPipeline, + FluxControlNetImg2ImgPipeline, + FluxControlNetInpaintPipeline, + FluxControlNetPipeline, + FluxImg2ImgPipeline, + FluxInpaintPipeline, + FluxPipeline, + FluxFillPipeline, + HunyuanDiTControlNetPipeline, + HunyuanDiTPAGPipeline, + HunyuanDiTPipeline, + I2VGenXLPipeline, + IFImg2ImgPipeline, + IFImg2ImgSuperResolutionPipeline, + IFInpaintingPipeline, + IFInpaintingSuperResolutionPipeline, + IFPipeline, + IFSuperResolutionPipeline, + ImageTextPipelineOutput, + Kandinsky3Img2ImgPipeline, + Kandinsky3Pipeline, + KandinskyCombinedPipeline, + KandinskyImg2ImgCombinedPipeline, + KandinskyImg2ImgPipeline, + KandinskyInpaintCombinedPipeline, + KandinskyInpaintPipeline, + KandinskyPipeline, + KandinskyPriorPipeline, + KandinskyV22CombinedPipeline, + KandinskyV22ControlnetImg2ImgPipeline, + KandinskyV22ControlnetPipeline, + KandinskyV22Img2ImgCombinedPipeline, + KandinskyV22Img2ImgPipeline, + KandinskyV22InpaintCombinedPipeline, + KandinskyV22InpaintPipeline, + KandinskyV22Pipeline, + KandinskyV22PriorEmb2EmbPipeline, + KandinskyV22PriorPipeline, + LatentConsistencyModelImg2ImgPipeline, + LatentConsistencyModelPipeline, + LattePipeline, + LDMTextToImagePipeline, + LEditsPPPipelineStableDiffusion, + LEditsPPPipelineStableDiffusionXL, + LuminaText2ImgPipeline, + MarigoldDepthPipeline, + MarigoldNormalsPipeline, + MusicLDMPipeline, + PaintByExamplePipeline, + PIAPipeline, + PixArtAlphaPipeline, + PixArtSigmaPAGPipeline, + PixArtSigmaPipeline, + SemanticStableDiffusionPipeline, + ShapEImg2ImgPipeline, + ShapEPipeline, + StableAudioPipeline, + StableAudioProjectionModel, + StableCascadeCombinedPipeline, + StableCascadeDecoderPipeline, + StableCascadePriorPipeline, + StableDiffusion3ControlNetPipeline, + StableDiffusion3Img2ImgPipeline, + StableDiffusion3InpaintPipeline, + StableDiffusion3PAGPipeline, + StableDiffusion3Pipeline, + StableDiffusionAdapterPipeline, + StableDiffusionAttendAndExcitePipeline, + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPAGPipeline, + StableDiffusionControlNetPipeline, + StableDiffusionControlNetXSPipeline, + StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, + StableDiffusionGLIGENPipeline, + StableDiffusionGLIGENTextImagePipeline, + StableDiffusionImageVariationPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, + StableDiffusionInstructPix2PixPipeline, + StableDiffusionLatentUpscalePipeline, + StableDiffusionLDM3DPipeline, + StableDiffusionModelEditingPipeline, + StableDiffusionPAGPipeline, + StableDiffusionPanoramaPipeline, + StableDiffusionParadigmsPipeline, + StableDiffusionPipeline, + StableDiffusionPipelineSafe, + StableDiffusionPix2PixZeroPipeline, + StableDiffusionSAGPipeline, + StableDiffusionUpscalePipeline, + StableDiffusionXLAdapterPipeline, + StableDiffusionXLControlNetImg2ImgPipeline, + StableDiffusionXLControlNetInpaintPipeline, + StableDiffusionXLControlNetPAGImg2ImgPipeline, + StableDiffusionXLControlNetPAGPipeline, + StableDiffusionXLControlNetPipeline, + StableDiffusionXLControlNetXSPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLPAGImg2ImgPipeline, + StableDiffusionXLPAGInpaintPipeline, + StableDiffusionXLPAGPipeline, + StableDiffusionXLPipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + StableVideoDiffusionPipeline, + TextToVideoSDPipeline, + TextToVideoZeroPipeline, + TextToVideoZeroSDXLPipeline, + UnCLIPImageVariationPipeline, + UnCLIPPipeline, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + VideoToVideoSDPipeline, + VQDiffusionPipeline, + WuerstchenCombinedPipeline, + WuerstchenDecoderPipeline, + WuerstchenPriorPipeline, + ) + + try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 + else: + from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline + + try: + if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403 + else: + from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline + try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 + else: + from .pipelines import ( + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInpaintPipelineLegacy, + OnnxStableDiffusionPipeline, + OnnxStableDiffusionUpscalePipeline, + StableDiffusionOnnxPipeline, + ) + + try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_librosa_objects import * # noqa F403 + else: + from .pipelines import AudioDiffusionPipeline, Mel + + try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 + else: + from .pipelines import SpectrogramDiffusionPipeline + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_objects import * # noqa F403 + else: + from .models.controlnet_flax import FlaxControlNetModel + from .models.modeling_flax_utils import FlaxModelMixin + from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.vae_flax import FlaxAutoencoderKL + from .pipelines import FlaxDiffusionPipeline + from .schedulers import ( + FlaxDDIMScheduler, + FlaxDDPMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxEulerDiscreteScheduler, + FlaxKarrasVeScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, + FlaxSchedulerMixin, + FlaxScoreSdeVeScheduler, + ) + + try: + if not (is_flax_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + from .pipelines import ( + FlaxStableDiffusionControlNetPipeline, + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + FlaxStableDiffusionXLPipeline, + ) + + try: + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_note_seq_objects import * # noqa F403 + else: + from .pipelines import MidiProcessor + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__}, + ) diff --git a/diffusers/src/diffusers/callbacks.py b/diffusers/src/diffusers/callbacks.py new file mode 100755 index 0000000..3854240 --- /dev/null +++ b/diffusers/src/diffusers/callbacks.py @@ -0,0 +1,156 @@ +from typing import Any, Dict, List + +from .configuration_utils import ConfigMixin, register_to_config +from .utils import CONFIG_NAME + + +class PipelineCallback(ConfigMixin): + """ + Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing + custom callbacks and ensures that all callbacks have a consistent interface. + + Please implement the following: + `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to + include + variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + `callback_fn`: This method defines the core functionality of your callback. + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None): + super().__init__() + + if (cutoff_step_ratio is None and cutoff_step_index is None) or ( + cutoff_step_ratio is not None and cutoff_step_index is not None + ): + raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.") + + if cutoff_step_ratio is not None and ( + not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0) + ): + raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.") + + @property + def tensor_inputs(self) -> List[str]: + raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}") + + def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]: + raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}") + + def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + return self.callback_fn(pipeline, step_index, timestep, callback_kwargs) + + +class MultiPipelineCallbacks: + """ + This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and + provides a unified interface for calling all of them. + """ + + def __init__(self, callbacks: List[PipelineCallback]): + self.callbacks = callbacks + + @property + def tensor_inputs(self) -> List[str]: + return [input for callback in self.callbacks for input in callback.tensor_inputs] + + def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + """ + Calls all the callbacks in order with the given arguments and returns the final callback_kwargs. + """ + for callback in self.callbacks: + callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs) + + return callback_kwargs + + +class SDCFGCutoffCallback(PipelineCallback): + """ + Callback function for Stable Diffusion Pipelines. After certain number of steps (set by `cutoff_step_ratio` or + `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. + """ + + tensor_inputs = ["prompt_embeds"] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + return callback_kwargs + + +class SDXLCFGCutoffCallback(PipelineCallback): + """ + Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or + `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. + """ + + tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. + + add_text_embeds = callback_kwargs[self.tensor_inputs[1]] + add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens + + add_time_ids = callback_kwargs[self.tensor_inputs[2]] + add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + callback_kwargs[self.tensor_inputs[1]] = add_text_embeds + callback_kwargs[self.tensor_inputs[2]] = add_time_ids + return callback_kwargs + + +class IPAdapterScaleCutoffCallback(PipelineCallback): + """ + Callback function for any pipeline that inherits `IPAdapterMixin`. After certain number of steps (set by + `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`. + + Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step. + """ + + tensor_inputs = [] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + pipeline.set_ip_adapter_scale(0.0) + return callback_kwargs diff --git a/diffusers/src/diffusers/commands/__init__.py b/diffusers/src/diffusers/commands/__init__.py new file mode 100755 index 0000000..8208283 --- /dev/null +++ b/diffusers/src/diffusers/commands/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from argparse import ArgumentParser + + +class BaseDiffusersCLICommand(ABC): + @staticmethod + @abstractmethod + def register_subcommand(parser: ArgumentParser): + raise NotImplementedError() + + @abstractmethod + def run(self): + raise NotImplementedError() diff --git a/diffusers/src/diffusers/commands/diffusers_cli.py b/diffusers/src/diffusers/commands/diffusers_cli.py new file mode 100755 index 0000000..f582c3b --- /dev/null +++ b/diffusers/src/diffusers/commands/diffusers_cli.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser + +from .env import EnvironmentCommand +from .fp16_safetensors import FP16SafetensorsCommand + + +def main(): + parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []") + commands_parser = parser.add_subparsers(help="diffusers-cli command helpers") + + # Register commands + EnvironmentCommand.register_subcommand(commands_parser) + FP16SafetensorsCommand.register_subcommand(commands_parser) + + # Let's go + args = parser.parse_args() + + if not hasattr(args, "func"): + parser.print_help() + exit(1) + + # Run + service = args.func(args) + service.run() + + +if __name__ == "__main__": + main() diff --git a/diffusers/src/diffusers/commands/env.py b/diffusers/src/diffusers/commands/env.py new file mode 100755 index 0000000..d0af30b --- /dev/null +++ b/diffusers/src/diffusers/commands/env.py @@ -0,0 +1,180 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform +import subprocess +from argparse import ArgumentParser + +import huggingface_hub + +from .. import __version__ as version +from ..utils import ( + is_accelerate_available, + is_bitsandbytes_available, + is_flax_available, + is_google_colab, + is_peft_available, + is_safetensors_available, + is_torch_available, + is_transformers_available, + is_xformers_available, +) +from . import BaseDiffusersCLICommand + + +def info_command_factory(_): + return EnvironmentCommand() + + +class EnvironmentCommand(BaseDiffusersCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser) -> None: + download_parser = parser.add_parser("env") + download_parser.set_defaults(func=info_command_factory) + + def run(self) -> dict: + hub_version = huggingface_hub.__version__ + + safetensors_version = "not installed" + if is_safetensors_available(): + import safetensors + + safetensors_version = safetensors.__version__ + + pt_version = "not installed" + pt_cuda_available = "NA" + if is_torch_available(): + import torch + + pt_version = torch.__version__ + pt_cuda_available = torch.cuda.is_available() + + flax_version = "not installed" + jax_version = "not installed" + jaxlib_version = "not installed" + jax_backend = "NA" + if is_flax_available(): + import flax + import jax + import jaxlib + + flax_version = flax.__version__ + jax_version = jax.__version__ + jaxlib_version = jaxlib.__version__ + jax_backend = jax.lib.xla_bridge.get_backend().platform + + transformers_version = "not installed" + if is_transformers_available(): + import transformers + + transformers_version = transformers.__version__ + + accelerate_version = "not installed" + if is_accelerate_available(): + import accelerate + + accelerate_version = accelerate.__version__ + + peft_version = "not installed" + if is_peft_available(): + import peft + + peft_version = peft.__version__ + + bitsandbytes_version = "not installed" + if is_bitsandbytes_available(): + import bitsandbytes + + bitsandbytes_version = bitsandbytes.__version__ + + xformers_version = "not installed" + if is_xformers_available(): + import xformers + + xformers_version = xformers.__version__ + + platform_info = platform.platform() + + is_google_colab_str = "Yes" if is_google_colab() else "No" + + accelerator = "NA" + if platform.system() in {"Linux", "Windows"}: + try: + sp = subprocess.Popen( + ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out_str, _ = sp.communicate() + out_str = out_str.decode("utf-8") + + if len(out_str) > 0: + accelerator = out_str.strip() + except FileNotFoundError: + pass + elif platform.system() == "Darwin": # Mac OS + try: + sp = subprocess.Popen( + ["system_profiler", "SPDisplaysDataType"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out_str, _ = sp.communicate() + out_str = out_str.decode("utf-8") + + start = out_str.find("Chipset Model:") + if start != -1: + start += len("Chipset Model:") + end = out_str.find("\n", start) + accelerator = out_str[start:end].strip() + + start = out_str.find("VRAM (Total):") + if start != -1: + start += len("VRAM (Total):") + end = out_str.find("\n", start) + accelerator += " VRAM: " + out_str[start:end].strip() + except FileNotFoundError: + pass + else: + print("It seems you are running an unusual OS. Could you fill in the accelerator manually?") + + info = { + "🤗 Diffusers version": version, + "Platform": platform_info, + "Running on Google Colab?": is_google_colab_str, + "Python version": platform.python_version(), + "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", + "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})", + "Jax version": jax_version, + "JaxLib version": jaxlib_version, + "Huggingface_hub version": hub_version, + "Transformers version": transformers_version, + "Accelerate version": accelerate_version, + "PEFT version": peft_version, + "Bitsandbytes version": bitsandbytes_version, + "Safetensors version": safetensors_version, + "xFormers version": xformers_version, + "Accelerator": accelerator, + "Using GPU in script?": "", + "Using distributed or parallel set-up in script?": "", + } + + print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") + print(self.format_dict(info)) + + return info + + @staticmethod + def format_dict(d: dict) -> str: + return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" diff --git a/diffusers/src/diffusers/commands/fp16_safetensors.py b/diffusers/src/diffusers/commands/fp16_safetensors.py new file mode 100755 index 0000000..b26b881 --- /dev/null +++ b/diffusers/src/diffusers/commands/fp16_safetensors.py @@ -0,0 +1,132 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage example: + diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors +""" + +import glob +import json +import warnings +from argparse import ArgumentParser, Namespace +from importlib import import_module + +import huggingface_hub +import torch +from huggingface_hub import hf_hub_download +from packaging import version + +from ..utils import logging +from . import BaseDiffusersCLICommand + + +def conversion_command_factory(args: Namespace): + if args.use_auth_token: + warnings.warn( + "The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now" + " handled automatically if user is logged in." + ) + return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors) + + +class FP16SafetensorsCommand(BaseDiffusersCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + conversion_parser = parser.add_parser("fp16_safetensors") + conversion_parser.add_argument( + "--ckpt_id", + type=str, + help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.", + ) + conversion_parser.add_argument( + "--fp16", action="store_true", help="If serializing the variables in FP16 precision." + ) + conversion_parser.add_argument( + "--use_safetensors", action="store_true", help="If serializing in the safetensors format." + ) + conversion_parser.add_argument( + "--use_auth_token", + action="store_true", + help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.", + ) + conversion_parser.set_defaults(func=conversion_command_factory) + + def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool): + self.logger = logging.get_logger("diffusers-cli/fp16_safetensors") + self.ckpt_id = ckpt_id + self.local_ckpt_dir = f"/tmp/{ckpt_id}" + self.fp16 = fp16 + + self.use_safetensors = use_safetensors + + if not self.use_safetensors and not self.fp16: + raise NotImplementedError( + "When `use_safetensors` and `fp16` both are False, then this command is of no use." + ) + + def run(self): + if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"): + raise ImportError( + "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub" + " installation." + ) + else: + from huggingface_hub import create_commit + from huggingface_hub._commit_api import CommitOperationAdd + + model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json") + with open(model_index, "r") as f: + pipeline_class_name = json.load(f)["_class_name"] + pipeline_class = getattr(import_module("diffusers"), pipeline_class_name) + self.logger.info(f"Pipeline class imported: {pipeline_class_name}.") + + # Load the appropriate pipeline. We could have use `DiffusionPipeline` + # here, but just to avoid any rough edge cases. + pipeline = pipeline_class.from_pretrained( + self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32 + ) + pipeline.save_pretrained( + self.local_ckpt_dir, + safe_serialization=True if self.use_safetensors else False, + variant="fp16" if self.fp16 else None, + ) + self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.") + + # Fetch all the paths. + if self.fp16: + modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*") + elif self.use_safetensors: + modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors") + + # Prepare for the PR. + commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}." + operations = [] + for path in modified_paths: + operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path)) + + # Open the PR. + commit_description = ( + "Variables converted by the [`diffusers`' `fp16_safetensors`" + " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)." + ) + hub_pr_url = create_commit( + repo_id=self.ckpt_id, + operations=operations, + commit_message=commit_message, + commit_description=commit_description, + repo_type="model", + create_pr=True, + ).pr_url + self.logger.info(f"PR created here: {hub_pr_url}.") diff --git a/diffusers/src/diffusers/configuration_utils.py b/diffusers/src/diffusers/configuration_utils.py new file mode 100755 index 0000000..3dccd78 --- /dev/null +++ b/diffusers/src/diffusers/configuration_utils.py @@ -0,0 +1,720 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ConfigMixin base class and utilities.""" + +import dataclasses +import functools +import importlib +import inspect +import json +import os +import re +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict, Tuple, Union + +import numpy as np +from huggingface_hub import create_repo, hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, + validate_hf_hub_args, +) +from requests import HTTPError + +from . import __version__ +from .utils import ( + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + DummyObject, + deprecate, + extract_commit_hash, + http_user_agent, + logging, +) + + +logger = logging.get_logger(__name__) + +_re_configuration_file = re.compile(r"config\.(.*)\.json") + + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setitem__(name, value) + + +class ConfigMixin: + r""" + Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also + provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and + saving classes that inherit from [`ConfigMixin`]. + + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overridden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overridden by subclass). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). + - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function + should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by + subclass). + """ + + config_name = None + ignore_for_config = [] + has_compatibles = False + + _deprecated_kwargs = [] + + def register_to_config(self, **kwargs): + if self.config_name is None: + raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") + # Special case for `kwargs` used in deprecation warning added to schedulers + # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, + # or solve in a more general way. + kwargs.pop("kwargs", None) + + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + logger.debug(f"Updating config from {previous_dict} to {internal_dict}") + + self._internal_dict = FrozenDict(internal_dict) + + def __getattr__(self, name: str) -> Any: + """The only reason we overwrite `getattr` here is to gracefully deprecate accessing + config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 + + This function is mostly copied from PyTorch's __getattr__ overwrite: + https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + """ + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) + return self._internal_dict[name] + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the + [`~ConfigMixin.from_config`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file is saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_config` + output_config_file = os.path.join(save_directory, self.config_name) + + self.to_json_file(output_config_file) + logger.info(f"Configuration saved in {output_config_file}") + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + + @classmethod + def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): + r""" + Instantiate a Python class from a config dictionary. + + Parameters: + config (`Dict[str, Any]`): + A config dictionary from which the Python class is instantiated. Make sure to only load configuration + files of compatible classes. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it is loaded) and initiate the Python class. + `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually + overwrite the same named arguments in `config`. + + Returns: + [`ModelMixin`] or [`SchedulerMixin`]: + A model or scheduler object instantiated from a config dictionary. + + Examples: + + ```python + >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler + + >>> # Download scheduler from huggingface.co and cache. + >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32") + + >>> # Instantiate DDIM scheduler class with same config as DDPM + >>> scheduler = DDIMScheduler.from_config(scheduler.config) + + >>> # Instantiate PNDM scheduler class with same config as DDPM + >>> scheduler = PNDMScheduler.from_config(scheduler.config) + ``` + """ + # <===== TO BE REMOVED WITH DEPRECATION + # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated + if "pretrained_model_name_or_path" in kwargs: + config = kwargs.pop("pretrained_model_name_or_path") + + if config is None: + raise ValueError("Please make sure to provide a config as the first positional argument.") + # ======> + + if not isinstance(config, dict): + deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." + if "Scheduler" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead." + " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will" + " be removed in v1.0.0." + ) + elif "Model" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a model, please use {cls}.load_config(...) followed by" + f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary" + " instead. This functionality will be removed in v1.0.0." + ) + deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False) + config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs) + + init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) + + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + + # add possible deprecated kwargs + for deprecated_kwarg in cls._deprecated_kwargs: + if deprecated_kwarg in unused_kwargs: + init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) + + # Return model and optionally state and/or unused_kwargs + model = cls(**init_dict) + + # make sure to also save config parameters that might be used for compatible classes + # update _class_name + if "_class_name" in hidden_dict: + hidden_dict["_class_name"] = cls.__name__ + + model.register_to_config(**hidden_dict) + + # add hidden kwargs of compatible classes to unused_kwargs + unused_kwargs = {**unused_kwargs, **hidden_dict} + + if return_unused_kwargs: + return (model, unused_kwargs) + else: + return model + + @classmethod + def get_config_dict(cls, *args, **kwargs): + deprecation_message = ( + f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be" + " removed in version v1.0.0" + ) + deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) + return cls.load_config(*args, **kwargs) + + @classmethod + @validate_hf_hub_args + def load_config( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + return_unused_kwargs=False, + return_commit_hash=False, + **kwargs, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + r""" + Load a model or scheduler configuration. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with + [`~ConfigMixin.save_config`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + return_unused_kwargs (`bool`, *optional*, defaults to `False): + Whether unused keyword arguments of the config are returned. + return_commit_hash (`bool`, *optional*, defaults to `False): + Whether the `commit_hash` of the loaded configuration are returned. + + Returns: + `dict`: + A dictionary of all the parameters stored in a JSON configuration file. + + """ + cache_dir = kwargs.pop("cache_dir", None) + local_dir = kwargs.pop("local_dir", None) + local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto") + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + _ = kwargs.pop("mirror", None) + subfolder = kwargs.pop("subfolder", None) + user_agent = kwargs.pop("user_agent", {}) + + user_agent = {**user_agent, "file_type": "config"} + user_agent = http_user_agent(user_agent) + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + if cls.config_name is None: + raise ValueError( + "`self.config_name` is not defined. Note that one should not load a config from " + "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" + ) + + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): + # Load from a PyTorch checkpoint + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + config_file = hf_hub_download( + pretrained_model_name_or_path, + filename=cls.config_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + ) + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" + " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" + " token having permission to this repo with `token` or log in with `huggingface-cli login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" + " this model name. Check the model page at" + f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}." + ) + except HTTPError as err: + raise EnvironmentError( + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" + " run the library in offline mode at" + " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a {cls.config_name} file" + ) + + try: + # Load config dict + config_dict = cls._dict_from_json_file(config_file) + + commit_hash = extract_commit_hash(config_file) + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") + + if not (return_unused_kwargs or return_commit_hash): + return config_dict + + outputs = (config_dict,) + + if return_unused_kwargs: + outputs += (kwargs,) + + if return_commit_hash: + outputs += (commit_hash,) + + return outputs + + @staticmethod + def _get_init_keys(input_class): + return set(dict(inspect.signature(input_class.__init__).parameters).keys()) + + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): + # Skip keys that were not present in the original config, so default __init__ values were used + used_defaults = config_dict.get("_use_default_values", []) + config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"} + + # 0. Copy origin config dict + original_dict = dict(config_dict.items()) + + # 1. Retrieve expected config attributes from __init__ signature + expected_keys = cls._get_init_keys(cls) + expected_keys.remove("self") + # remove general kwargs if present in dict + if "kwargs" in expected_keys: + expected_keys.remove("kwargs") + # remove flax internal keys + if hasattr(cls, "_flax_internal_args"): + for arg in cls._flax_internal_args: + expected_keys.remove(arg) + + # 2. Remove attributes that cannot be expected from expected config attributes + # remove keys to be ignored + if len(cls.ignore_for_config) > 0: + expected_keys = expected_keys - set(cls.ignore_for_config) + + # load diffusers library to import compatible and original scheduler + diffusers_library = importlib.import_module(__name__.split(".")[0]) + + if cls.has_compatibles: + compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)] + else: + compatible_classes = [] + + expected_keys_comp_cls = set() + for c in compatible_classes: + expected_keys_c = cls._get_init_keys(c) + expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c) + expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls) + config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls} + + # remove attributes from orig class that cannot be expected + orig_cls_name = config_dict.pop("_class_name", cls.__name__) + if ( + isinstance(orig_cls_name, str) + and orig_cls_name != cls.__name__ + and hasattr(diffusers_library, orig_cls_name) + ): + orig_cls = getattr(diffusers_library, orig_cls_name) + unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys + config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig} + elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)): + raise ValueError( + "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)." + ) + + # remove private attributes + config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + + # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments + init_dict = {} + for key in expected_keys: + # if config param is passed to kwarg and is present in config dict + # it should overwrite existing config dict key + if key in kwargs and key in config_dict: + config_dict[key] = kwargs.pop(key) + + if key in kwargs: + # overwrite key + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + # use value from config dict + init_dict[key] = config_dict.pop(key) + + # 4. Give nice warning if unexpected values have been passed + if len(config_dict) > 0: + logger.warning( + f"The config attributes {config_dict} were passed to {cls.__name__}, " + "but are not expected and will be ignored. Please verify your " + f"{cls.config_name} configuration file." + ) + + # 5. Give nice info if config attributes are initialized to default because they have not been passed + passed_keys = set(init_dict.keys()) + if len(expected_keys - passed_keys) > 0: + logger.info( + f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." + ) + + # 6. Define unused keyword arguments + unused_kwargs = {**config_dict, **kwargs} + + # 7. Define "hidden" config parameters that were saved for compatible classes + hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} + + return init_dict, unused_kwargs, hidden_config_dict + + @classmethod + def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @property + def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ + return self._internal_dict + + def to_json_string(self) -> str: + """ + Serializes the configuration instance to a JSON string. + + Returns: + `str`: + String containing all the attributes that make up the configuration instance in JSON format. + """ + config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + config_dict["_class_name"] = self.__class__.__name__ + config_dict["_diffusers_version"] = __version__ + + def to_json_saveable(value): + if isinstance(value, np.ndarray): + value = value.tolist() + elif isinstance(value, Path): + value = value.as_posix() + return value + + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} + # Don't save "_ignore_files" or "_use_default_values" + config_dict.pop("_ignore_files", None) + config_dict.pop("_use_default_values", None) + + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save the configuration instance's parameters to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file to save a configuration instance's parameters. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +def register_to_config(init): + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs)) + + new_kwargs = {**config_init_kwargs, **new_kwargs} + getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init + + +def flax_register_to_config(cls): + original_init = cls.__init__ + + @functools.wraps(original_init) + def init(self, *args, **kwargs): + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + # Ignore private kwargs in the init. Retrieve all passed attributes + init_kwargs = dict(kwargs.items()) + + # Retrieve default values + fields = dataclasses.fields(self) + default_kwargs = {} + for field in fields: + # ignore flax specific attributes + if field.name in self._flax_internal_args: + continue + if type(field.default) == dataclasses._MISSING_TYPE: + default_kwargs[field.name] = None + else: + default_kwargs[field.name] = getattr(self, field.name) + + # Make sure init_kwargs override default kwargs + new_kwargs = {**default_kwargs, **init_kwargs} + # dtype should be part of `init_kwargs`, but not `new_kwargs` + if "dtype" in new_kwargs: + new_kwargs.pop("dtype") + + # Get positional arguments aligned with kwargs + for i, arg in enumerate(args): + name = fields[i].name + new_kwargs[name] = arg + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs)) + + getattr(self, "register_to_config")(**new_kwargs) + original_init(self, *args, **kwargs) + + cls.__init__ = init + return cls + + +class LegacyConfigMixin(ConfigMixin): + r""" + A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more + pipeline-specific classes (like `DiTTransformer2DModel`). + """ + + @classmethod + def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): + # To prevent dependency import problem. + from .models.model_loading_utils import _fetch_remapped_cls_from_config + + # resolve remapping + remapped_class = _fetch_remapped_cls_from_config(config, cls) + + return remapped_class.from_config(config, return_unused_kwargs, **kwargs) diff --git a/diffusers/src/diffusers/dependency_versions_check.py b/diffusers/src/diffusers/dependency_versions_check.py new file mode 100755 index 0000000..0728b3a --- /dev/null +++ b/diffusers/src/diffusers/dependency_versions_check.py @@ -0,0 +1,34 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .dependency_versions_table import deps +from .utils.versions import require_version, require_version_core + + +# define which module versions we always want to check at run time +# (usually the ones defined in `install_requires` in setup.py) +# +# order specific notes: +# - tqdm must be checked before tokenizers + +pkgs_to_check_at_runtime = "python requests filelock numpy".split() +for pkg in pkgs_to_check_at_runtime: + if pkg in deps: + require_version_core(deps[pkg]) + else: + raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") + + +def dep_version_check(pkg, hint=None): + require_version(deps[pkg], hint) diff --git a/diffusers/src/diffusers/dependency_versions_table.py b/diffusers/src/diffusers/dependency_versions_table.py new file mode 100755 index 0000000..9e7bf24 --- /dev/null +++ b/diffusers/src/diffusers/dependency_versions_table.py @@ -0,0 +1,46 @@ +# THIS FILE HAS BEEN AUTOGENERATED. To update: +# 1. modify the `_deps` dict in setup.py +# 2. run `make deps_table_update` +deps = { + "Pillow": "Pillow", + "accelerate": "accelerate>=0.31.0", + "compel": "compel==0.1.8", + "datasets": "datasets", + "filelock": "filelock", + "flax": "flax>=0.4.1", + "hf-doc-builder": "hf-doc-builder>=0.3.0", + "huggingface-hub": "huggingface-hub>=0.23.2", + "requests-mock": "requests-mock==1.10.0", + "importlib_metadata": "importlib_metadata", + "invisible-watermark": "invisible-watermark>=0.2.0", + "isort": "isort>=5.5.4", + "jax": "jax>=0.4.1", + "jaxlib": "jaxlib>=0.4.1", + "Jinja2": "Jinja2", + "k-diffusion": "k-diffusion>=0.0.12", + "torchsde": "torchsde", + "note_seq": "note_seq", + "librosa": "librosa", + "numpy": "numpy", + "parameterized": "parameterized", + "peft": "peft>=0.6.0", + "protobuf": "protobuf>=3.20.3,<4", + "pytest": "pytest", + "pytest-timeout": "pytest-timeout", + "pytest-xdist": "pytest-xdist", + "python": "python>=3.8.0", + "ruff": "ruff==0.1.5", + "safetensors": "safetensors>=0.3.1", + "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", + "GitPython": "GitPython<3.1.19", + "scipy": "scipy", + "onnx": "onnx", + "regex": "regex!=2019.12.17", + "requests": "requests", + "tensorboard": "tensorboard", + "torch": "torch>=1.4", + "torchvision": "torchvision", + "transformers": "transformers>=4.41.2", + "urllib3": "urllib3<=2.0.0", + "black": "black", +} diff --git a/diffusers/src/diffusers/experimental/README.md b/diffusers/src/diffusers/experimental/README.md new file mode 100755 index 0000000..81a9de8 --- /dev/null +++ b/diffusers/src/diffusers/experimental/README.md @@ -0,0 +1,5 @@ +# 🧨 Diffusers Experimental + +We are adding experimental code to support novel applications and usages of the Diffusers library. +Currently, the following experiments are supported: +* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model. \ No newline at end of file diff --git a/diffusers/src/diffusers/experimental/__init__.py b/diffusers/src/diffusers/experimental/__init__.py new file mode 100755 index 0000000..ebc8155 --- /dev/null +++ b/diffusers/src/diffusers/experimental/__init__.py @@ -0,0 +1 @@ +from .rl import ValueGuidedRLPipeline diff --git a/diffusers/src/diffusers/experimental/rl/__init__.py b/diffusers/src/diffusers/experimental/rl/__init__.py new file mode 100755 index 0000000..7b338d3 --- /dev/null +++ b/diffusers/src/diffusers/experimental/rl/__init__.py @@ -0,0 +1 @@ +from .value_guided_sampling import ValueGuidedRLPipeline diff --git a/diffusers/src/diffusers/experimental/rl/value_guided_sampling.py b/diffusers/src/diffusers/experimental/rl/value_guided_sampling.py new file mode 100755 index 0000000..2f9de85 --- /dev/null +++ b/diffusers/src/diffusers/experimental/rl/value_guided_sampling.py @@ -0,0 +1,153 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import tqdm + +from ...models.unets.unet_1d import UNet1DModel +from ...pipelines import DiffusionPipeline +from ...utils.dummy_pt_objects import DDPMScheduler +from ...utils.torch_utils import randn_tensor + + +class ValueGuidedRLPipeline(DiffusionPipeline): + r""" + Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + value_function ([`UNet1DModel`]): + A specialized UNet for fine-tuning trajectories base on reward. + unet ([`UNet1DModel`]): + UNet architecture to denoise the encoded trajectories. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this + application is [`DDPMScheduler`]. + env (): + An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models. + """ + + def __init__( + self, + value_function: UNet1DModel, + unet: UNet1DModel, + scheduler: DDPMScheduler, + env, + ): + super().__init__() + + self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env) + + self.data = env.get_dataset() + self.means = {} + for key in self.data.keys(): + try: + self.means[key] = self.data[key].mean() + except: # noqa: E722 + pass + self.stds = {} + for key in self.data.keys(): + try: + self.stds[key] = self.data[key].std() + except: # noqa: E722 + pass + self.state_dim = env.observation_space.shape[0] + self.action_dim = env.action_space.shape[0] + + def normalize(self, x_in, key): + return (x_in - self.means[key]) / self.stds[key] + + def de_normalize(self, x_in, key): + return x_in * self.stds[key] + self.means[key] + + def to_torch(self, x_in): + if isinstance(x_in, dict): + return {k: self.to_torch(v) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(self.unet.device) + return torch.tensor(x_in, device=self.unet.device) + + def reset_x0(self, x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + def run_diffusion(self, x, conditions, n_guide_steps, scale): + batch_size = x.shape[0] + y = None + for i in tqdm.tqdm(self.scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) + for _ in range(n_guide_steps): + with torch.enable_grad(): + x.requires_grad_() + + # permute to match dimension for pre-trained models + y = self.value_function(x.permute(0, 2, 1), timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + + posterior_variance = self.scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + + grad[timesteps < 2] = 0 + x = x.detach() + x = x + scale * grad + x = self.reset_x0(x, conditions, self.action_dim) + + prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + + # TODO: verify deprecation of this kwarg + x = self.scheduler.step(prev_x, i, x)["prev_sample"] + + # apply conditions to the trajectory (set the initial state) + x = self.reset_x0(x, conditions, self.action_dim) + x = self.to_torch(x) + return x, y + + def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + # normalize the observations and create batch dimension + obs = self.normalize(obs, "observations") + obs = obs[None].repeat(batch_size, axis=0) + + conditions = {0: self.to_torch(obs)} + shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + + # generate initial noise and apply our conditions (to make the trajectories start at current state) + x1 = randn_tensor(shape, device=self.unet.device) + x = self.reset_x0(x1, conditions, self.action_dim) + x = self.to_torch(x) + + # run the diffusion process + x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + + # sort output trajectories by value + sorted_idx = y.argsort(0, descending=True).squeeze() + sorted_values = x[sorted_idx] + actions = sorted_values[:, :, : self.action_dim] + actions = actions.detach().cpu().numpy() + denorm_actions = self.de_normalize(actions, key="actions") + + # select the action with the highest value + if y is not None: + selected_index = 0 + else: + # if we didn't run value guiding, select a random action + selected_index = np.random.randint(0, batch_size) + + denorm_actions = denorm_actions[selected_index, 0] + return denorm_actions diff --git a/diffusers/src/diffusers/image_processor.py b/diffusers/src/diffusers/image_processor.py new file mode 100755 index 0000000..4086eda --- /dev/null +++ b/diffusers/src/diffusers/image_processor.py @@ -0,0 +1,1314 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from PIL import Image, ImageFilter, ImageOps + +from .configuration_utils import ConfigMixin, register_to_config +from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate + + +PipelineImageInput = Union[ + PIL.Image.Image, + np.ndarray, + torch.Tensor, + List[PIL.Image.Image], + List[np.ndarray], + List[torch.Tensor], +] + +PipelineDepthInput = PipelineImageInput + + +def is_valid_image(image) -> bool: + r""" + Checks if the input is a valid image. + + A valid image can be: + - A `PIL.Image.Image`. + - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. + + Returns: + `bool`: + `True` if the input is a valid image, `False` otherwise. + """ + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + + +def is_valid_image_imagelist(images): + r""" + Checks if the input is a valid image or list of images. + + The input can be one of the following formats: + - A 4D tensor or numpy array (batch of images). + - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or + `torch.Tensor`. + - A list of valid images. + + Args: + images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): + The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid + images. + + Returns: + `bool`: + `True` if the input is valid, `False` otherwise. + """ + if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: + return True + elif is_valid_image(images): + return True + elif isinstance(images, list): + return all(is_valid_image(image) for image in images) + return False + + +class VaeImageProcessor(ConfigMixin): + """ + Image processor for VAE. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. + do_convert_grayscale (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to grayscale format. + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + vae_latent_channels: int = 4, + resample: str = "lanczos", + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_rgb: bool = False, + do_convert_grayscale: bool = False, + ): + super().__init__() + if do_convert_rgb and do_convert_grayscale: + raise ValueError( + "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`," + " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.", + " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`", + ) + + @staticmethod + def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: + r""" + Convert a numpy image or a batch of images to a PIL image. + + Args: + images (`np.ndarray`): + The image array to convert to PIL format. + + Returns: + `List[PIL.Image.Image]`: + A list of PIL images. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + @staticmethod + def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: + r""" + Convert a PIL image or a list of PIL images to NumPy arrays. + + Args: + images (`PIL.Image.Image` or `List[PIL.Image.Image]`): + The PIL image or list of images to convert to NumPy format. + + Returns: + `np.ndarray`: + A NumPy array representation of the images. + """ + if not isinstance(images, list): + images = [images] + images = [np.array(image).astype(np.float32) / 255.0 for image in images] + images = np.stack(images, axis=0) + + return images + + @staticmethod + def numpy_to_pt(images: np.ndarray) -> torch.Tensor: + r""" + Convert a NumPy image to a PyTorch tensor. + + Args: + images (`np.ndarray`): + The NumPy image array to convert to PyTorch format. + + Returns: + `torch.Tensor`: + A PyTorch tensor representation of the images. + """ + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + @staticmethod + def pt_to_numpy(images: torch.Tensor) -> np.ndarray: + r""" + Convert a PyTorch tensor to a NumPy image. + + Args: + images (`torch.Tensor`): + The PyTorch tensor to convert to NumPy format. + + Returns: + `np.ndarray`: + A NumPy array representation of the images. + """ + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + return images + + @staticmethod + def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: + r""" + Normalize an image array to [-1,1]. + + Args: + images (`np.ndarray` or `torch.Tensor`): + The image array to normalize. + + Returns: + `np.ndarray` or `torch.Tensor`: + The normalized image array. + """ + return 2.0 * images - 1.0 + + @staticmethod + def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: + r""" + Denormalize an image array to [0,1]. + + Args: + images (`np.ndarray` or `torch.Tensor`): + The image array to denormalize. + + Returns: + `np.ndarray` or `torch.Tensor`: + The denormalized image array. + """ + return (images * 0.5 + 0.5).clamp(0, 1) + + @staticmethod + def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: + r""" + Converts a PIL image to RGB format. + + Args: + image (`PIL.Image.Image`): + The PIL image to convert to RGB. + + Returns: + `PIL.Image.Image`: + The RGB-converted PIL image. + """ + image = image.convert("RGB") + + return image + + @staticmethod + def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image: + r""" + Converts a given PIL image to grayscale. + + Args: + image (`PIL.Image.Image`): + The input image to convert. + + Returns: + `PIL.Image.Image`: + The image converted to grayscale. + """ + image = image.convert("L") + + return image + + @staticmethod + def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image: + r""" + Applies Gaussian blur to an image. + + Args: + image (`PIL.Image.Image`): + The PIL image to convert to grayscale. + + Returns: + `PIL.Image.Image`: + The grayscale-converted PIL image. + """ + image = image.filter(ImageFilter.GaussianBlur(blur_factor)) + + return image + + @staticmethod + def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0): + r""" + Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect + ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for + processing are 512x512, the region will be expanded to 128x128. + + Args: + mask_image (PIL.Image.Image): Mask image. + width (int): Width of the image to be processed. + height (int): Height of the image to be processed. + pad (int, optional): Padding to be added to the crop region. Defaults to 0. + + Returns: + tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and + matches the original aspect ratio. + """ + + mask_image = mask_image.convert("L") + mask = np.array(mask_image) + + # 1. find a rectangular region that contains all masked ares in an image + h, w = mask.shape + crop_left = 0 + for i in range(w): + if not (mask[:, i] == 0).all(): + break + crop_left += 1 + + crop_right = 0 + for i in reversed(range(w)): + if not (mask[:, i] == 0).all(): + break + crop_right += 1 + + crop_top = 0 + for i in range(h): + if not (mask[i] == 0).all(): + break + crop_top += 1 + + crop_bottom = 0 + for i in reversed(range(h)): + if not (mask[i] == 0).all(): + break + crop_bottom += 1 + + # 2. add padding to the crop region + x1, y1, x2, y2 = ( + int(max(crop_left - pad, 0)), + int(max(crop_top - pad, 0)), + int(min(w - crop_right + pad, w)), + int(min(h - crop_bottom + pad, h)), + ) + + # 3. expands crop region to match the aspect ratio of the image to be processed + ratio_crop_region = (x2 - x1) / (y2 - y1) + ratio_processing = width / height + + if ratio_crop_region > ratio_processing: + desired_height = (x2 - x1) / ratio_processing + desired_height_diff = int(desired_height - (y2 - y1)) + y1 -= desired_height_diff // 2 + y2 += desired_height_diff - desired_height_diff // 2 + if y2 >= mask_image.height: + diff = y2 - mask_image.height + y2 -= diff + y1 -= diff + if y1 < 0: + y2 -= y1 + y1 -= y1 + if y2 >= mask_image.height: + y2 = mask_image.height + else: + desired_width = (y2 - y1) * ratio_processing + desired_width_diff = int(desired_width - (x2 - x1)) + x1 -= desired_width_diff // 2 + x2 += desired_width_diff - desired_width_diff // 2 + if x2 >= mask_image.width: + diff = x2 - mask_image.width + x2 -= diff + x1 -= diff + if x1 < 0: + x2 -= x1 + x1 -= x1 + if x2 >= mask_image.width: + x2 = mask_image.width + + return x1, y1, x2, y2 + + def _resize_and_fill( + self, + image: PIL.Image.Image, + width: int, + height: int, + ) -> PIL.Image.Image: + r""" + Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center + the image within the dimensions, filling empty with data from image. + + Args: + image (`PIL.Image.Image`): + The image to resize and fill. + width (`int`): + The width to resize the image to. + height (`int`): + The height to resize the image to. + + Returns: + `PIL.Image.Image`: + The resized and filled image. + """ + + ratio = width / height + src_ratio = image.width / image.height + + src_w = width if ratio < src_ratio else image.width * height // image.height + src_h = height if ratio >= src_ratio else image.height * width // image.width + + resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) + res = Image.new("RGB", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + + if ratio < src_ratio: + fill_height = height // 2 - src_h // 2 + if fill_height > 0: + res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) + res.paste( + resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), + box=(0, fill_height + src_h), + ) + elif ratio > src_ratio: + fill_width = width // 2 - src_w // 2 + if fill_width > 0: + res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) + res.paste( + resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), + box=(fill_width + src_w, 0), + ) + + return res + + def _resize_and_crop( + self, + image: PIL.Image.Image, + width: int, + height: int, + ) -> PIL.Image.Image: + r""" + Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center + the image within the dimensions, cropping the excess. + + Args: + image (`PIL.Image.Image`): + The image to resize and crop. + width (`int`): + The width to resize the image to. + height (`int`): + The height to resize the image to. + + Returns: + `PIL.Image.Image`: + The resized and cropped image. + """ + ratio = width / height + src_ratio = image.width / image.height + + src_w = width if ratio > src_ratio else image.width * height // image.height + src_h = height if ratio <= src_ratio else image.height * width // image.width + + resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) + res = Image.new("RGB", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + return res + + def resize( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + height: int, + width: int, + resize_mode: str = "default", # "default", "fill", "crop" + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: + """ + Resize image. + + Args: + image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): + The image input, can be a PIL image, numpy array or pytorch tensor. + height (`int`): + The height to resize to. + width (`int`): + The width to resize to. + resize_mode (`str`, *optional*, defaults to `default`): + The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit + within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, + will resize the image to fit within the specified width and height, maintaining the aspect ratio, and + then center the image within the dimensions, filling empty with data from image. If `crop`, will resize + the image to fit within the specified width and height, maintaining the aspect ratio, and then center + the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only + supported for PIL image input. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: + The resized image. + """ + if resize_mode != "default" and not isinstance(image, PIL.Image.Image): + raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}") + if isinstance(image, PIL.Image.Image): + if resize_mode == "default": + image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) + elif resize_mode == "fill": + image = self._resize_and_fill(image, width, height) + elif resize_mode == "crop": + image = self._resize_and_crop(image, width, height) + else: + raise ValueError(f"resize_mode {resize_mode} is not supported") + + elif isinstance(image, torch.Tensor): + image = torch.nn.functional.interpolate( + image, + size=(height, width), + ) + elif isinstance(image, np.ndarray): + image = self.numpy_to_pt(image) + image = torch.nn.functional.interpolate( + image, + size=(height, width), + ) + image = self.pt_to_numpy(image) + return image + + def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image: + """ + Create a mask. + + Args: + image (`PIL.Image.Image`): + The image input, should be a PIL image. + + Returns: + `PIL.Image.Image`: + The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1. + """ + image[image < 0.5] = 0 + image[image >= 0.5] = 1 + + return image + + def _denormalize_conditionally( + self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None + ) -> torch.Tensor: + r""" + Denormalize a batch of images based on a condition list. + + Args: + images (`torch.Tensor`): + The input image tensor. + do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`): + A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the + value of `do_normalize` in the `VaeImageProcessor` config. + """ + if do_denormalize is None: + return self.denormalize(images) if self.config.do_normalize else images + + return torch.stack( + [self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])] + ) + + def get_default_height_width( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + ) -> Tuple[int, int]: + r""" + Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`. + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it + should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch + tensor, it should have shape `[batch, channels, height, width]`. + height (`Optional[int]`, *optional*, defaults to `None`): + The height of the preprocessed image. If `None`, the height of the `image` input will be used. + width (`Optional[int]`, *optional*, defaults to `None`): + The width of the preprocessed image. If `None`, the width of the `image` input will be used. + + Returns: + `Tuple[int, int]`: + A tuple containing the height and width, both resized to the nearest integer multiple of + `vae_scale_factor`. + """ + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + else: + height = image.shape[1] + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + else: + width = image.shape[2] + + width, height = ( + x - x % self.config.vae_scale_factor for x in (width, height) + ) # resize to integer multiple of vae_scale_factor + + return height, width + + def preprocess( + self, + image: PipelineImageInput, + height: Optional[int] = None, + width: Optional[int] = None, + resize_mode: str = "default", # "default", "fill", "crop" + crops_coords: Optional[Tuple[int, int, int, int]] = None, + ) -> torch.Tensor: + """ + Preprocess the image input. + + Args: + image (`PipelineImageInput`): + The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of + supported formats. + height (`int`, *optional*): + The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default + height. + width (`int`, *optional*): + The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. + resize_mode (`str`, *optional*, defaults to `default`): + The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within + the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will + resize the image to fit within the specified width and height, maintaining the aspect ratio, and then + center the image within the dimensions, filling empty with data from image. If `crop`, will resize the + image to fit within the specified width and height, maintaining the aspect ratio, and then center the + image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only + supported for PIL image input. + crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): + The crop coordinates for each image in the batch. If `None`, will not crop the image. + + Returns: + `torch.Tensor`: + The preprocessed image. + """ + supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) + + # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image + if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3: + if isinstance(image, torch.Tensor): + # if image is a pytorch tensor could have 2 possible shapes: + # 1. batch x height x width: we should insert the channel dimension at position 1 + # 2. channel x height x width: we should insert batch dimension at position 0, + # however, since both channel and batch dimension has same size 1, it is same to insert at position 1 + # for simplicity, we insert a dimension of size 1 at position 1 for both cases + image = image.unsqueeze(1) + else: + # if it is a numpy array, it could have 2 possible shapes: + # 1. batch x height x width: insert channel dimension on last position + # 2. height x width x channel: insert batch dimension on first position + if image.shape[-1] == 1: + image = np.expand_dims(image, axis=0) + else: + image = np.expand_dims(image, axis=-1) + + if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray", + FutureWarning, + ) + image = np.concatenate(image, axis=0) + if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor", + FutureWarning, + ) + image = torch.cat(image, axis=0) + + if not is_valid_image_imagelist(image): + raise ValueError( + f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}" + ) + if not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + if crops_coords is not None: + image = [i.crop(crops_coords) for i in image] + if self.config.do_resize: + height, width = self.get_default_height_width(image[0], height, width) + image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image] + if self.config.do_convert_rgb: + image = [self.convert_to_rgb(i) for i in image] + elif self.config.do_convert_grayscale: + image = [self.convert_to_grayscale(i) for i in image] + image = self.pil_to_numpy(image) # to np + image = self.numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + + image = self.numpy_to_pt(image) + + height, width = self.get_default_height_width(image, height, width) + if self.config.do_resize: + image = self.resize(image, height, width) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + if self.config.do_convert_grayscale and image.ndim == 3: + image = image.unsqueeze(1) + + channel = image.shape[1] + # don't need any preprocess if the image is latents + if channel == self.config.vae_latent_channels: + return image + + height, width = self.get_default_height_width(image, height, width) + if self.config.do_resize: + image = self.resize(image, height, width) + + # expected range [0,1], normalize to [-1,1] + do_normalize = self.config.do_normalize + if do_normalize and image.min() < 0: + warnings.warn( + "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " + f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", + FutureWarning, + ) + do_normalize = False + if do_normalize: + image = self.normalize(image) + + if self.config.do_binarize: + image = self.binarize(image) + + return image + + def postprocess( + self, + image: torch.Tensor, + output_type: str = "pil", + do_denormalize: Optional[List[bool]] = None, + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: + """ + Postprocess the image output from tensor to `output_type`. + + Args: + image (`torch.Tensor`): + The image input, should be a pytorch tensor with shape `B x C x H x W`. + output_type (`str`, *optional*, defaults to `pil`): + The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. + do_denormalize (`List[bool]`, *optional*, defaults to `None`): + Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the + `VaeImageProcessor` config. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: + The postprocessed image. + """ + if not isinstance(image, torch.Tensor): + raise ValueError( + f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" + ) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" + + if output_type == "latent": + return image + + image = self._denormalize_conditionally(image, do_denormalize) + + if output_type == "pt": + return image + + image = self.pt_to_numpy(image) + + if output_type == "np": + return image + + if output_type == "pil": + return self.numpy_to_pil(image) + + def apply_overlay( + self, + mask: PIL.Image.Image, + init_image: PIL.Image.Image, + image: PIL.Image.Image, + crop_coords: Optional[Tuple[int, int, int, int]] = None, + ) -> PIL.Image.Image: + r""" + Applies an overlay of the mask and the inpainted image on the original image. + + Args: + mask (`PIL.Image.Image`): + The mask image that highlights regions to overlay. + init_image (`PIL.Image.Image`): + The original image to which the overlay is applied. + image (`PIL.Image.Image`): + The image to overlay onto the original. + crop_coords (`Tuple[int, int, int, int]`, *optional*): + Coordinates to crop the image. If provided, the image will be cropped accordingly. + + Returns: + `PIL.Image.Image`: + The final image with the overlay applied. + """ + + width, height = init_image.width, init_image.height + + init_image_masked = PIL.Image.new("RGBa", (width, height)) + init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L"))) + + init_image_masked = init_image_masked.convert("RGBA") + + if crop_coords is not None: + x, y, x2, y2 = crop_coords + w = x2 - x + h = y2 - y + base_image = PIL.Image.new("RGBA", (width, height)) + image = self.resize(image, height=h, width=w, resize_mode="crop") + base_image.paste(image, (x, y)) + image = base_image.convert("RGB") + + image = image.convert("RGBA") + image.alpha_composite(init_image_masked) + image = image.convert("RGB") + + return image + + +class VaeImageProcessorLDM3D(VaeImageProcessor): + """ + Image processor for VAE LDM3D. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + resample: str = "lanczos", + do_normalize: bool = True, + ): + super().__init__() + + @staticmethod + def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: + r""" + Convert a NumPy image or a batch of images to a list of PIL images. + + Args: + images (`np.ndarray`): + The input NumPy array of images, which can be a single image or a batch. + + Returns: + `List[PIL.Image.Image]`: + A list of PIL images converted from the input NumPy array. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image[:, :, :3]) for image in images] + + return pil_images + + @staticmethod + def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: + r""" + Convert a PIL image or a list of PIL images to NumPy arrays. + + Args: + images (`Union[List[PIL.Image.Image], PIL.Image.Image]`): + The input image or list of images to be converted. + + Returns: + `np.ndarray`: + A NumPy array of the converted images. + """ + if not isinstance(images, list): + images = [images] + + images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images] + images = np.stack(images, axis=0) + return images + + @staticmethod + def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: + r""" + Convert an RGB-like depth image to a depth map. + + Args: + image (`Union[np.ndarray, torch.Tensor]`): + The RGB-like depth image to convert. + + Returns: + `Union[np.ndarray, torch.Tensor]`: + The corresponding depth map. + """ + return image[:, :, 1] * 2**8 + image[:, :, 2] + + def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]: + r""" + Convert a NumPy depth image or a batch of images to a list of PIL images. + + Args: + images (`np.ndarray`): + The input NumPy array of depth images, which can be a single image or a batch. + + Returns: + `List[PIL.Image.Image]`: + A list of PIL images converted from the input NumPy depth images. + """ + if images.ndim == 3: + images = images[None, ...] + images_depth = images[:, :, :, 3:] + if images.shape[-1] == 6: + images_depth = (images_depth * 255).round().astype("uint8") + pil_images = [ + Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth + ] + elif images.shape[-1] == 4: + images_depth = (images_depth * 65535.0).astype(np.uint16) + pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth] + else: + raise Exception("Not supported") + + return pil_images + + def postprocess( + self, + image: torch.Tensor, + output_type: str = "pil", + do_denormalize: Optional[List[bool]] = None, + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: + """ + Postprocess the image output from tensor to `output_type`. + + Args: + image (`torch.Tensor`): + The image input, should be a pytorch tensor with shape `B x C x H x W`. + output_type (`str`, *optional*, defaults to `pil`): + The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. + do_denormalize (`List[bool]`, *optional*, defaults to `None`): + Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the + `VaeImageProcessor` config. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: + The postprocessed image. + """ + if not isinstance(image, torch.Tensor): + raise ValueError( + f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" + ) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" + + image = self._denormalize_conditionally(image, do_denormalize) + + image = self.pt_to_numpy(image) + + if output_type == "np": + if image.shape[-1] == 6: + image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0) + else: + image_depth = image[:, :, :, 3:] + return image[:, :, :, :3], image_depth + + if output_type == "pil": + return self.numpy_to_pil(image), self.numpy_to_depth(image) + else: + raise Exception(f"This type {output_type} is not supported") + + def preprocess( + self, + rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray], + depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray], + height: Optional[int] = None, + width: Optional[int] = None, + target_res: Optional[int] = None, + ) -> torch.Tensor: + r""" + Preprocess the image input. Accepted formats are PIL images, NumPy arrays, or PyTorch tensors. + + Args: + rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`): + The RGB input image, which can be a single image or a batch. + depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`): + The depth input image, which can be a single image or a batch. + height (`Optional[int]`, *optional*, defaults to `None`): + The desired height of the processed image. If `None`, defaults to the height of the input image. + width (`Optional[int]`, *optional*, defaults to `None`): + The desired width of the processed image. If `None`, defaults to the width of the input image. + target_res (`Optional[int]`, *optional*, defaults to `None`): + Target resolution for resizing the images. If specified, overrides height and width. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing the processed RGB and depth images as PyTorch tensors. + """ + supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) + + # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image + if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3: + raise Exception("This is not yet supported") + + if isinstance(rgb, supported_formats): + rgb = [rgb] + depth = [depth] + elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}" + ) + + if isinstance(rgb[0], PIL.Image.Image): + if self.config.do_convert_rgb: + raise Exception("This is not yet supported") + # rgb = [self.convert_to_rgb(i) for i in rgb] + # depth = [self.convert_to_depth(i) for i in depth] #TODO define convert_to_depth + if self.config.do_resize or target_res: + height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res + rgb = [self.resize(i, height, width) for i in rgb] + depth = [self.resize(i, height, width) for i in depth] + rgb = self.pil_to_numpy(rgb) # to np + rgb = self.numpy_to_pt(rgb) # to pt + + depth = self.depth_pil_to_numpy(depth) # to np + depth = self.numpy_to_pt(depth) # to pt + + elif isinstance(rgb[0], np.ndarray): + rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0) + rgb = self.numpy_to_pt(rgb) + height, width = self.get_default_height_width(rgb, height, width) + if self.config.do_resize: + rgb = self.resize(rgb, height, width) + + depth = np.concatenate(depth, axis=0) if rgb[0].ndim == 4 else np.stack(depth, axis=0) + depth = self.numpy_to_pt(depth) + height, width = self.get_default_height_width(depth, height, width) + if self.config.do_resize: + depth = self.resize(depth, height, width) + + elif isinstance(rgb[0], torch.Tensor): + raise Exception("This is not yet supported") + # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0) + + # if self.config.do_convert_grayscale and rgb.ndim == 3: + # rgb = rgb.unsqueeze(1) + + # channel = rgb.shape[1] + + # height, width = self.get_default_height_width(rgb, height, width) + # if self.config.do_resize: + # rgb = self.resize(rgb, height, width) + + # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0) + + # if self.config.do_convert_grayscale and depth.ndim == 3: + # depth = depth.unsqueeze(1) + + # channel = depth.shape[1] + # # don't need any preprocess if the image is latents + # if depth == 4: + # return rgb, depth + + # height, width = self.get_default_height_width(depth, height, width) + # if self.config.do_resize: + # depth = self.resize(depth, height, width) + # expected range [0,1], normalize to [-1,1] + do_normalize = self.config.do_normalize + if rgb.min() < 0 and do_normalize: + warnings.warn( + "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " + f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]", + FutureWarning, + ) + do_normalize = False + + if do_normalize: + rgb = self.normalize(rgb) + depth = self.normalize(depth) + + if self.config.do_binarize: + rgb = self.binarize(rgb) + depth = self.binarize(depth) + + return rgb, depth + + +class IPAdapterMaskProcessor(VaeImageProcessor): + """ + Image processor for IP Adapter image masks. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `False`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the image to 0/1. + do_convert_grayscale (`bool`, *optional*, defaults to be `True`): + Whether to convert the images to grayscale format. + + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + resample: str = "lanczos", + do_normalize: bool = False, + do_binarize: bool = True, + do_convert_grayscale: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + resample=resample, + do_normalize=do_normalize, + do_binarize=do_binarize, + do_convert_grayscale=do_convert_grayscale, + ) + + @staticmethod + def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int): + """ + Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the + aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued. + + Args: + mask (`torch.Tensor`): + The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`. + batch_size (`int`): + The batch size. + num_queries (`int`): + The number of queries. + value_embed_dim (`int`): + The dimensionality of the value embeddings. + + Returns: + `torch.Tensor`: + The downsampled mask tensor. + + """ + o_h = mask.shape[1] + o_w = mask.shape[2] + ratio = o_w / o_h + mask_h = int(math.sqrt(num_queries / ratio)) + mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0) + mask_w = num_queries // mask_h + + mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0) + + # Repeat batch_size times + if mask_downsample.shape[0] < batch_size: + mask_downsample = mask_downsample.repeat(batch_size, 1, 1) + + mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1) + + downsampled_area = mask_h * mask_w + # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match + # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries + if downsampled_area < num_queries: + warnings.warn( + "The aspect ratio of the mask does not match the aspect ratio of the output image. " + "Please update your masks or adjust the output size for optimal performance.", + UserWarning, + ) + mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0) + # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries + if downsampled_area > num_queries: + warnings.warn( + "The aspect ratio of the mask does not match the aspect ratio of the output image. " + "Please update your masks or adjust the output size for optimal performance.", + UserWarning, + ) + mask_downsample = mask_downsample[:, :num_queries] + + # Repeat last dimension to match SDPA output shape + mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat( + 1, 1, value_embed_dim + ) + + return mask_downsample + + +class PixArtImageProcessor(VaeImageProcessor): + """ + Image processor for PixArt image resize and crop. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. + do_convert_grayscale (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to grayscale format. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + resample: str = "lanczos", + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_grayscale: bool = False, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + resample=resample, + do_normalize=do_normalize, + do_binarize=do_binarize, + do_convert_grayscale=do_convert_grayscale, + ) + + @staticmethod + def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]: + r""" + Returns the binned height and width based on the aspect ratio. + + Args: + height (`int`): The height of the image. + width (`int`): The width of the image. + ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width). + + Returns: + `Tuple[int, int]`: The closest binned height and width. + """ + ar = float(height / width) + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) + default_hw = ratios[closest_ratio] + return int(default_hw[0]), int(default_hw[1]) + + @staticmethod + def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor: + r""" + Resizes and crops a tensor of images to the specified dimensions. + + Args: + samples (`torch.Tensor`): + A tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the height, + and W is the width. + new_width (`int`): The desired width of the output images. + new_height (`int`): The desired height of the output images. + + Returns: + `torch.Tensor`: A tensor containing the resized and cropped images. + """ + orig_height, orig_width = samples.shape[2], samples.shape[3] + + # Check if resizing is needed + if orig_height != new_height or orig_width != new_width: + ratio = max(new_height / orig_height, new_width / orig_width) + resized_width = int(orig_width * ratio) + resized_height = int(orig_height * ratio) + + # Resize + samples = F.interpolate( + samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + # Center Crop + start_x = (resized_width - new_width) // 2 + end_x = start_x + new_width + start_y = (resized_height - new_height) // 2 + end_y = start_y + new_height + samples = samples[:, :, start_y:end_y, start_x:end_x] + + return samples \ No newline at end of file diff --git a/diffusers/src/diffusers/loaders/__init__.py b/diffusers/src/diffusers/loaders/__init__.py new file mode 100755 index 0000000..bf72122 --- /dev/null +++ b/diffusers/src/diffusers/loaders/__init__.py @@ -0,0 +1,102 @@ +from typing import TYPE_CHECKING + +from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate +from ..utils.import_utils import is_peft_available, is_torch_available, is_transformers_available + + +def text_encoder_lora_state_dict(text_encoder): + deprecate( + "text_encoder_load_state_dict in `models`", + "0.27.0", + "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.", + ) + state_dict = {} + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + +if is_transformers_available(): + + def text_encoder_attn_modules(text_encoder): + deprecate( + "text_encoder_attn_modules in `models`", + "0.27.0", + "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.", + ) + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + else: + raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") + + return attn_modules + + +_import_structure = {} + +if is_torch_available(): + _import_structure["single_file_model"] = ["FromOriginalModelMixin"] + + _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] + _import_structure["utils"] = ["AttnProcsLayers"] + if is_transformers_available(): + _import_structure["single_file"] = ["FromSingleFileMixin"] + _import_structure["lora_pipeline"] = [ + "AmusedLoraLoaderMixin", + "StableDiffusionLoraLoaderMixin", + "SD3LoraLoaderMixin", + "StableDiffusionXLLoraLoaderMixin", + "LoraLoaderMixin", + "FluxLoraLoaderMixin", + "CogVideoXLoraLoaderMixin", + ] + _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] + _import_structure["ip_adapter"] = ["IPAdapterMixin"] + +_import_structure["peft"] = ["PeftAdapterMixin"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + if is_torch_available(): + from .single_file_model import FromOriginalModelMixin + from .unet import UNet2DConditionLoadersMixin + from .utils import AttnProcsLayers + + if is_transformers_available(): + from .ip_adapter import IPAdapterMixin + from .lora_pipeline import ( + AmusedLoraLoaderMixin, + CogVideoXLoraLoaderMixin, + FluxLoraLoaderMixin, + LoraLoaderMixin, + SD3LoraLoaderMixin, + StableDiffusionLoraLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ) + from .single_file import FromSingleFileMixin + from .textual_inversion import TextualInversionLoaderMixin + + from .peft import PeftAdapterMixin +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/diffusers/src/diffusers/loaders/ip_adapter.py b/diffusers/src/diffusers/loaders/ip_adapter.py new file mode 100755 index 0000000..1006dab --- /dev/null +++ b/diffusers/src/diffusers/loaders/ip_adapter.py @@ -0,0 +1,348 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Dict, List, Optional, Union + +import torch +import torch.nn.functional as F +from huggingface_hub.utils import validate_hf_hub_args +from safetensors import safe_open + +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict +from ..utils import ( + USE_PEFT_BACKEND, + _get_model_file, + is_accelerate_available, + is_torch_version, + is_transformers_available, + logging, +) +from .unet_loader_utils import _maybe_expand_lora_scales + + +if is_transformers_available(): + from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + ) + + from ..models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + ) + +logger = logging.get_logger(__name__) + + +class IPAdapterMixin: + """Mixin for handling IP Adapters.""" + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]], + subfolder: Union[str, List[str]], + weight_name: Union[str, List[str]], + image_encoder_folder: Optional[str] = "image_encoder", + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + subfolder (`str` or `List[str]`): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. + weight_name (`str` or `List[str]`): + The name of the weight file to load. If a list is passed, it should have the same length as + `weight_name`. + image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): + The subfolder location of the image encoder within a larger model repository on the Hub or locally. + Pass `None` to not load the image encoder. If the image encoder is located in a folder inside + `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g. + `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than + `subfolder`, you should pass the path to the folder that contains image encoder weights, for example, + `image_encoder_folder="different_subfolder/image_encoder"`. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + + # handle the list inputs for multiple IP Adapters + if not isinstance(weight_name, list): + weight_name = [weight_name] + + if not isinstance(pretrained_model_name_or_path_or_dict, list): + pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] + if len(pretrained_model_name_or_path_or_dict) == 1: + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name) + + if not isinstance(subfolder, list): + subfolder = [subfolder] + if len(subfolder) == 1: + subfolder = subfolder * len(weight_name) + + if len(weight_name) != len(pretrained_model_name_or_path_or_dict): + raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.") + + if len(weight_name) != len(subfolder): + raise ValueError("`weight_name` and `subfolder` must have the same length.") + + # Load the main state dict first. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + state_dicts = [] + for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( + pretrained_model_name_or_path_or_dict, weight_name, subfolder + ): + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if keys != ["image_proj", "ip_adapter"]: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + state_dicts.append(state_dict) + + # load CLIP image encoder here if it has not been registered to the pipeline yet + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + if image_encoder_folder is not None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") + if image_encoder_folder.count("/") == 0: + image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix() + else: + image_encoder_subfolder = Path(image_encoder_folder).as_posix() + + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + pretrained_model_name_or_path_or_dict, + subfolder=image_encoder_subfolder, + low_cpu_mem_usage=low_cpu_mem_usage, + cache_dir=cache_dir, + local_files_only=local_files_only, + ).to(self.device, dtype=self.dtype) + self.register_modules(image_encoder=image_encoder) + else: + raise ValueError( + "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict." + ) + else: + logger.warning( + "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter." + "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead." + ) + + # create feature extractor if it has not been registered to the pipeline yet + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: + # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224 + default_clip_size = 224 + clip_image_size = ( + self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size + ) + feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size) + self.register_modules(feature_extractor=feature_extractor) + + # load ip-adapter into unet + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + + extra_loras = unet._load_ip_adapter_loras(state_dicts) + if extra_loras != {}: + if not USE_PEFT_BACKEND: + logger.warning("PEFT backend is required to load these weights.") + else: + # apply the IP Adapter Face ID LoRA weights + peft_config = getattr(unet, "peft_config", {}) + for k, lora in extra_loras.items(): + if f"faceid_{k}" not in peft_config: + self.load_lora_weights(lora, adapter_name=f"faceid_{k}") + self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0]) + + def set_ip_adapter_scale(self, scale): + """ + Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for + granular control over each IP-Adapter behavior. A config can be a float or a dictionary. + + Example: + + ```py + # To use original IP-Adapter + scale = 1.0 + pipeline.set_ip_adapter_scale(scale) + + # To use style block only + scale = { + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style+layout blocks + scale = { + "down": {"block_2": [0.0, 1.0]}, + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style and layout from 2 reference images + scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}] + pipeline.set_ip_adapter_scale(scales) + ``` + """ + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + if not isinstance(scale, list): + scale = [scale] + scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0) + + for attn_name, attn_processor in unet.attn_processors.items(): + if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): + if len(scale_configs) != len(attn_processor.scale): + raise ValueError( + f"Cannot assign {len(scale_configs)} scale_configs to " + f"{len(attn_processor.scale)} IP-Adapter." + ) + elif len(scale_configs) == 1: + scale_configs = scale_configs * len(attn_processor.scale) + for i, scale_config in enumerate(scale_configs): + if isinstance(scale_config, dict): + for k, s in scale_config.items(): + if attn_name.startswith(k): + attn_processor.scale[i] = s + else: + attn_processor.scale[i] = scale_config + + def unload_ip_adapter(self): + """ + Unloads the IP Adapter weights + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... + ``` + """ + # remove CLIP image encoder + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: + self.image_encoder = None + self.register_to_config(image_encoder=[None, None]) + + # remove feature extractor only when safety_checker is None as safety_checker uses + # the feature_extractor later + if not hasattr(self, "safety_checker"): + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: + self.feature_extractor = None + self.register_to_config(feature_extractor=[None, None]) + + # remove hidden encoder + self.unet.encoder_hid_proj = None + self.unet.config.encoder_hid_dim_type = None + + # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj` + if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None: + self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj + self.unet.text_encoder_hid_proj = None + self.unet.config.encoder_hid_dim_type = "text_proj" + + # restore original Unet attention processors layers + attn_procs = {} + for name, value in self.unet.attn_processors.items(): + attn_processor_class = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor() + ) + attn_procs[name] = ( + attn_processor_class + if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)) + else value.__class__() + ) + self.unet.set_attn_processor(attn_procs) diff --git a/diffusers/src/diffusers/loaders/lora_base.py b/diffusers/src/diffusers/loaders/lora_base.py new file mode 100755 index 0000000..f18b978 --- /dev/null +++ b/diffusers/src/diffusers/loaders/lora_base.py @@ -0,0 +1,753 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +import os +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import safetensors +import torch +import torch.nn as nn +from huggingface_hub import model_info +from huggingface_hub.constants import HF_HUB_OFFLINE + +from ..models.modeling_utils import ModelMixin, load_state_dict +from ..utils import ( + USE_PEFT_BACKEND, + _get_model_file, + delete_adapter_layers, + deprecate, + is_accelerate_available, + is_peft_available, + is_transformers_available, + logging, + recurse_remove_peft_layers, + set_adapter_layers, + set_weights_and_activate_adapters, +) + + +if is_transformers_available(): + from transformers import PreTrainedModel + +if is_peft_available(): + from peft.tuners.tuners_utils import BaseTunerLayer + +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + +logger = logging.get_logger(__name__) + + +def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): + """ + Fuses LoRAs for the text encoder. + + Args: + text_encoder (`torch.nn.Module`): + The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` + attribute. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]` or `str`): + The names of the adapters to use. + """ + merge_kwargs = {"safe_merge": safe_fusing} + + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + if lora_scale != 1.0: + module.scale_layer(lora_scale) + + # For BC with previous PEFT versions, we need to check the signature + # of the `merge` method to see if it supports the `adapter_names` argument. + supported_merge_kwargs = list(inspect.signature(module.merge).parameters) + if "adapter_names" in supported_merge_kwargs: + merge_kwargs["adapter_names"] = adapter_names + elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: + raise ValueError( + "The `adapter_names` argument is not supported with your PEFT version. " + "Please upgrade to the latest version of PEFT. `pip install -U peft`" + ) + + module.merge(**merge_kwargs) + + +def unfuse_text_encoder_lora(text_encoder): + """ + Unfuses LoRAs for the text encoder. + + Args: + text_encoder (`torch.nn.Module`): + The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` + attribute. + """ + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + +def set_adapters_for_text_encoder( + adapter_names: Union[List[str], str], + text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821 + text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None, +): + """ + Sets the adapter layers for the text encoder. + + Args: + adapter_names (`List[str]` or `str`): + The names of the adapters to use. + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` + attribute. + text_encoder_weights (`List[float]`, *optional*): + The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters. + """ + if text_encoder is None: + raise ValueError( + "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead." + ) + + def process_weights(adapter_names, weights): + # Expand weights into a list, one entry per adapter + # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None] + if not isinstance(weights, list): + weights = [weights] * len(adapter_names) + + if len(adapter_names) != len(weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}" + ) + + # Set None values to default of 1.0 + # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1] + weights = [w if w is not None else 1.0 for w in weights] + + return weights + + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + text_encoder_weights = process_weights(adapter_names, text_encoder_weights) + set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights) + + +def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None): + """ + Disables the LoRA layers for the text encoder. + + Args: + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder` + attribute. + """ + if text_encoder is None: + raise ValueError("Text Encoder not found.") + set_adapter_layers(text_encoder, enabled=False) + + +def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None): + """ + Enables the LoRA layers for the text encoder. + + Args: + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder` + attribute. + """ + if text_encoder is None: + raise ValueError("Text Encoder not found.") + set_adapter_layers(text_encoder, enabled=True) + + +def _remove_text_encoder_monkey_patch(text_encoder): + recurse_remove_peft_layers(text_encoder) + if getattr(text_encoder, "peft_config", None) is not None: + del text_encoder.peft_config + text_encoder._hf_peft_config_loaded = None + + +class LoraBaseMixin: + """Utility class for handling LoRAs.""" + + _lora_loadable_modules = [] + num_fused_loras = 0 + + def load_lora_weights(self, **kwargs): + raise NotImplementedError("`load_lora_weights()` is not implemented.") + + @classmethod + def save_lora_weights(cls, **kwargs): + raise NotImplementedError("`save_lora_weights()` not implemented.") + + @classmethod + def lora_state_dict(cls, **kwargs): + raise NotImplementedError("`lora_state_dict()` is not implemented.") + + @classmethod + def _optionally_disable_offloading(cls, _pipeline): + """ + Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. + + Args: + _pipeline (`DiffusionPipeline`): + The pipeline to disable offloading for. + + Returns: + tuple: + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + """ + is_model_cpu_offload = False + is_sequential_cpu_offload = False + + if _pipeline is not None and _pipeline.hf_device_map is None: + for _, component in _pipeline.components.items(): + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if not is_model_cpu_offload: + is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) + if not is_sequential_cpu_offload: + is_sequential_cpu_offload = ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) + + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + + return (is_model_cpu_offload, is_sequential_cpu_offload) + + @classmethod + def _fetch_state_dict( + cls, + pretrained_model_name_or_path_or_dict, + weight_name, + use_safetensors, + local_files_only, + cache_dir, + force_download, + proxies, + token, + revision, + subfolder, + user_agent, + allow_pickle, + ): + from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + # Here we're relaxing the loading check to enable more Inference API + # friendliness where sometimes, it's not at all possible to automatically + # determine `weight_name`. + if weight_name is None: + weight_name = cls._best_guess_weight_name( + pretrained_model_name_or_path_or_dict, + file_extension=".safetensors", + local_files_only=local_files_only, + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except (IOError, safetensors.SafetensorError) as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + model_file = None + pass + + if model_file is None: + if weight_name is None: + weight_name = cls._best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + return state_dict + + @classmethod + def _best_guess_weight_name( + cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False + ): + from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + + if local_files_only or HF_HUB_OFFLINE: + raise ValueError("When using the offline mode, you must specify a `weight_name`.") + + targeted_files = [] + + if os.path.isfile(pretrained_model_name_or_path_or_dict): + return + elif os.path.isdir(pretrained_model_name_or_path_or_dict): + targeted_files = [ + f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension) + ] + else: + files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings + targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] + if len(targeted_files) == 0: + return + + # "scheduler" does not correspond to a LoRA checkpoint. + # "optimizer" does not correspond to a LoRA checkpoint + # only top-level checkpoints are considered and not the other ones, hence "checkpoint". + unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} + targeted_files = list( + filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) + ) + + if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) + elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) + + if len(targeted_files) > 1: + raise ValueError( + f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." + ) + weight_name = targeted_files[0] + return weight_name + + def unload_lora_weights(self): + """ + Unloads the LoRA parameters. + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the LoRA parameters. + >>> pipeline.unload_lora_weights() + >>> ... + ``` + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if model is not None: + if issubclass(model.__class__, ModelMixin): + model.unload_lora() + elif issubclass(model.__class__, PreTrainedModel): + _remove_text_encoder_monkey_patch(model) + + def fuse_lora( + self, + components: List[str] = [], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + if "fuse_unet" in kwargs: + depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version." + deprecate( + "fuse_unet", + "1.0.0", + depr_message, + ) + if "fuse_transformer" in kwargs: + depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version." + deprecate( + "fuse_transformer", + "1.0.0", + depr_message, + ) + if "fuse_text_encoder" in kwargs: + depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version." + deprecate( + "fuse_text_encoder", + "1.0.0", + depr_message, + ) + + if len(components) == 0: + raise ValueError("`components` cannot be an empty list.") + + for fuse_component in components: + # if fuse_component not in self._lora_loadable_modules: + # raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.") + + model = getattr(self, fuse_component, None) + if model is not None: + # check if diffusers model + if issubclass(model.__class__, ModelMixin): + model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) + # handle transformers models. + if issubclass(model.__class__, PreTrainedModel): + fuse_text_encoder_lora( + model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + self.num_fused_loras += 1 + + def unfuse_lora(self, components: List[str] = [], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + if "unfuse_unet" in kwargs: + depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version." + deprecate( + "unfuse_unet", + "1.0.0", + depr_message, + ) + if "unfuse_transformer" in kwargs: + depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version." + deprecate( + "unfuse_transformer", + "1.0.0", + depr_message, + ) + if "unfuse_text_encoder" in kwargs: + depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version." + deprecate( + "unfuse_text_encoder", + "1.0.0", + depr_message, + ) + + if len(components) == 0: + raise ValueError("`components` cannot be an empty list.") + + for fuse_component in components: + if fuse_component not in self._lora_loadable_modules: + raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.") + + model = getattr(self, fuse_component, None) + if model is not None: + if issubclass(model.__class__, (ModelMixin, PreTrainedModel)): + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + self.num_fused_loras -= 1 + + def set_adapters( + self, + adapter_names: Union[List[str], str], + adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None, + ): + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + + adapter_weights = copy.deepcopy(adapter_weights) + + # Expand weights into a list, one entry per adapter + if not isinstance(adapter_weights, list): + adapter_weights = [adapter_weights] * len(adapter_names) + + if len(adapter_names) != len(adapter_weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}" + ) + + list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]} + all_adapters = { + adapter for adapters in list_adapters.values() for adapter in adapters + } # eg ["adapter1", "adapter2"] + invert_list_adapters = { + adapter: [part for part, adapters in list_adapters.items() if adapter in adapters] + for adapter in all_adapters + } # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]} + + # Decompose weights into weights for denoiser and text encoders. + _component_adapter_weights = {} + for component in self._lora_loadable_modules: + model = getattr(self, component) + + for adapter_name, weights in zip(adapter_names, adapter_weights): + if isinstance(weights, dict): + component_adapter_weights = weights.pop(component, None) + + if component_adapter_weights is not None and not hasattr(self, component): + logger.warning( + f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}." + ) + + if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]: + logger.warning( + ( + f"Lora weight dict for adapter '{adapter_name}' contains {component}," + f"but this will be ignored because {adapter_name} does not contain weights for {component}." + f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}." + ) + ) + + else: + component_adapter_weights = weights + + _component_adapter_weights.setdefault(component, []) + _component_adapter_weights[component].append(component_adapter_weights) + + if issubclass(model.__class__, ModelMixin): + model.set_adapters(adapter_names, _component_adapter_weights[component]) + elif issubclass(model.__class__, PreTrainedModel): + set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component]) + + def disable_lora(self): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if model is not None: + if issubclass(model.__class__, ModelMixin): + model.disable_lora() + elif issubclass(model.__class__, PreTrainedModel): + disable_lora_for_text_encoder(model) + + def enable_lora(self): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if model is not None: + if issubclass(model.__class__, ModelMixin): + model.enable_lora() + elif issubclass(model.__class__, PreTrainedModel): + enable_lora_for_text_encoder(model) + + def delete_adapters(self, adapter_names: Union[List[str], str]): + """ + Args: + Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s). + adapter_names (`Union[List[str], str]`): + The names of the adapter to delete. Can be a single string or a list of strings + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + if isinstance(adapter_names, str): + adapter_names = [adapter_names] + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if model is not None: + if issubclass(model.__class__, ModelMixin): + model.delete_adapters(adapter_names) + elif issubclass(model.__class__, PreTrainedModel): + for adapter_name in adapter_names: + delete_adapter_layers(model, adapter_name) + + def get_active_adapters(self) -> List[str]: + """ + Gets the list of the current active adapters. + + Example: + + ```python + from diffusers import DiffusionPipeline + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + ).to("cuda") + pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy") + pipeline.get_active_adapters() + ``` + """ + if not USE_PEFT_BACKEND: + raise ValueError( + "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`" + ) + + active_adapters = [] + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if model is not None and issubclass(model.__class__, ModelMixin): + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + active_adapters = module.active_adapters + break + + return active_adapters + + def get_list_adapters(self) -> Dict[str, List[str]]: + """ + Gets the current list of all available adapters in the pipeline. + """ + if not USE_PEFT_BACKEND: + raise ValueError( + "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`" + ) + + set_adapters = {} + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if ( + model is not None + and issubclass(model.__class__, (ModelMixin, PreTrainedModel)) + and hasattr(model, "peft_config") + ): + set_adapters[component] = list(model.peft_config.keys()) + + return set_adapters + + def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None: + """ + Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case + you want to load multiple adapters and free some GPU memory. + + Args: + adapter_names (`List[str]`): + List of adapters to send device to. + device (`Union[torch.device, str, int]`): + Device to send the adapters to. Can be either a torch device, a str or an integer. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if model is not None: + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + for adapter_name in adapter_names: + module.lora_A[adapter_name].to(device) + module.lora_B[adapter_name].to(device) + # this is a param, not a module, so device placement is not in-place -> re-assign + if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None: + if adapter_name in module.lora_magnitude_vector: + module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[ + adapter_name + ].to(device) + + @staticmethod + def pack_weights(layers, prefix): + layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + @staticmethod + def write_lora_layers( + state_dict: Dict[str, torch.Tensor], + save_directory: str, + is_main_process: bool, + weight_name: str, + save_function: Callable, + safe_serialization: bool, + ): + from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + if weight_name is None: + if safe_serialization: + weight_name = LORA_WEIGHT_NAME_SAFE + else: + weight_name = LORA_WEIGHT_NAME + + save_path = Path(save_directory, weight_name).as_posix() + save_function(state_dict, save_path) + logger.info(f"Model weights saved in {save_path}") + + @property + def lora_scale(self) -> float: + # property function that returns the lora scale which can be set at run time by the pipeline. + # if _lora_scale has not been set, return 1 + return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 diff --git a/diffusers/src/diffusers/loaders/lora_conversion_utils.py b/diffusers/src/diffusers/loaders/lora_conversion_utils.py new file mode 100755 index 0000000..f6dea33 --- /dev/null +++ b/diffusers/src/diffusers/loaders/lora_conversion_utils.py @@ -0,0 +1,623 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import torch + +from ..utils import is_peft_version, logging + + +logger = logging.get_logger(__name__) + + +def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5): + # 1. get all state_dict_keys + all_keys = list(state_dict.keys()) + sgm_patterns = ["input_blocks", "middle_block", "output_blocks"] + + # 2. check if needs remapping, if not return original dict + is_in_sgm_format = False + for key in all_keys: + if any(p in key for p in sgm_patterns): + is_in_sgm_format = True + break + + if not is_in_sgm_format: + return state_dict + + # 3. Else remap from SGM patterns + new_state_dict = {} + inner_block_map = ["resnets", "attentions", "upsamplers"] + + # Retrieves # of down, mid and up blocks + input_block_ids, middle_block_ids, output_block_ids = set(), set(), set() + + for layer in all_keys: + if "text" in layer: + new_state_dict[layer] = state_dict.pop(layer) + else: + layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) + if sgm_patterns[0] in layer: + input_block_ids.add(layer_id) + elif sgm_patterns[1] in layer: + middle_block_ids.add(layer_id) + elif sgm_patterns[2] in layer: + output_block_ids.add(layer_id) + else: + raise ValueError(f"Checkpoint not supported because layer {layer} not supported.") + + input_blocks = { + layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] + for layer_id in input_block_ids + } + middle_blocks = { + layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] + for layer_id in middle_block_ids + } + output_blocks = { + layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key] + for layer_id in output_block_ids + } + + # Rename keys accordingly + for i in input_block_ids: + block_id = (i - 1) // (unet_config.layers_per_block + 1) + layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1) + + for key in input_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers" + inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in middle_block_ids: + key_part = None + if i == 0: + key_part = [inner_block_map[0], "0"] + elif i == 1: + key_part = [inner_block_map[1], "0"] + elif i == 2: + key_part = [inner_block_map[0], "1"] + else: + raise ValueError(f"Invalid middle block id {i}.") + + for key in middle_blocks[i]: + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in output_block_ids: + block_id = i // (unet_config.layers_per_block + 1) + layer_in_block_id = i % (unet_config.layers_per_block + 1) + + for key in output_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] + inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + if len(state_dict) > 0: + raise ValueError("At this point all state dict entries have to be converted.") + + return new_state_dict + + +def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"): + """ + Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict. + + Args: + state_dict (`dict`): The state dict to convert. + unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet". + text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to + "text_encoder". + + Returns: + `tuple`: A tuple containing the converted state dict and a dictionary of alphas. + """ + unet_state_dict = {} + te_state_dict = {} + te2_state_dict = {} + network_alphas = {} + + # Check for DoRA-enabled LoRAs. + dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict) + dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict) + dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict) + if dora_present_in_unet or dora_present_in_te or dora_present_in_te2: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + + # Iterate over all LoRA weights. + all_lora_keys = list(state_dict.keys()) + for key in all_lora_keys: + if not key.endswith("lora_down.weight"): + continue + + # Extract LoRA name. + lora_name = key.split(".")[0] + + # Find corresponding up weight and alpha. + lora_name_up = lora_name + ".lora_up.weight" + lora_name_alpha = lora_name + ".alpha" + + # Handle U-Net LoRAs. + if lora_name.startswith("lora_unet_"): + diffusers_name = _convert_unet_lora_key(key) + + # Store down and up weights. + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # Store DoRA scale if present. + if dora_present_in_unet: + dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." + unet_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + + # Handle text encoder LoRAs. + elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): + diffusers_name = _convert_text_encoder_lora_key(key, lora_name) + + # Store down and up weights for te or te2. + if lora_name.startswith(("lora_te_", "lora_te1_")): + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + else: + te2_state_dict[diffusers_name] = state_dict.pop(key) + te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # Store DoRA scale if present. + if dora_present_in_te or dora_present_in_te2: + dora_scale_key_to_replace_te = ( + "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." + ) + if lora_name.startswith(("lora_te_", "lora_te1_")): + te_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + elif lora_name.startswith("lora_te2_"): + te2_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + + # Store alpha if present. + if lora_name_alpha in state_dict: + alpha = state_dict.pop(lora_name_alpha).item() + network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha)) + + # Check if any keys remain. + if len(state_dict) > 0: + raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}") + + logger.info("Non-diffusers checkpoint detected.") + + # Construct final state dict. + unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} + te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()} + te2_state_dict = ( + {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()} + if len(te2_state_dict) > 0 + else None + ) + if te2_state_dict is not None: + te_state_dict.update(te2_state_dict) + + new_state_dict = {**unet_state_dict, **te_state_dict} + return new_state_dict, network_alphas + + +def _convert_unet_lora_key(key): + """ + Converts a U-Net LoRA key to a Diffusers compatible key. + """ + diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + + # Replace common U-Net naming patterns. + diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("middle.block", "mid_block") + diffusers_name = diffusers_name.replace("mid.block", "mid_block") + diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("proj.in", "proj_in") + diffusers_name = diffusers_name.replace("proj.out", "proj_out") + diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") + + # SDXL specific conversions. + if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name: + pattern = r"\.\d+(?=\D*$)" + diffusers_name = re.sub(pattern, "", diffusers_name, count=1) + if ".in." in diffusers_name: + diffusers_name = diffusers_name.replace("in.layers.2", "conv1") + if ".out." in diffusers_name: + diffusers_name = diffusers_name.replace("out.layers.3", "conv2") + if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: + diffusers_name = diffusers_name.replace("op", "conv") + if "skip" in diffusers_name: + diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") + + # LyCORIS specific conversions. + if "time.emb.proj" in diffusers_name: + diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj") + if "conv.shortcut" in diffusers_name: + diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut") + + # General conversions. + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") + elif "ff" in diffusers_name: + pass + elif any(key in diffusers_name for key in ("proj_in", "proj_out")): + pass + else: + pass + + return diffusers_name + + +def _convert_text_encoder_lora_key(key, lora_name): + """ + Converts a text encoder LoRA key to a Diffusers compatible key. + """ + if lora_name.startswith(("lora_te_", "lora_te1_")): + key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_" + else: + key_to_replace = "lora_te2_" + + diffusers_name = key.replace(key_to_replace, "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("text.projection", "text_projection") + + if "self_attn" in diffusers_name or "text_projection" in diffusers_name: + pass + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + return diffusers_name + + +def _get_alpha_name(lora_name_alpha, diffusers_name, alpha): + """ + Gets the correct alpha name for the Diffusers model. + """ + if lora_name_alpha.startswith("lora_unet_"): + prefix = "unet." + elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): + prefix = "text_encoder." + else: + prefix = "text_encoder_2." + new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" + return {new_name: alpha} + + +# The utilities under `_convert_kohya_flux_lora_to_diffusers()` +# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py +# All credits go to `kohya-ss`. +def _convert_kohya_flux_lora_to_diffusers(state_dict): + def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar + scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + + # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down + ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up + + def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] + + # scale weight by alpha and dim + alpha = sds_sd.pop(sds_key + ".alpha") + scale = alpha / sd_lora_rank + + # calculate scale_down and scale_up + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up + + # calculate dims if not provided + num_splits = len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # check upweight is sparse or not + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all( + up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0 + ) + i += dims[j] + if is_sparse: + logger.info(f"weight is sparse: {sds_key}") + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + if not is_sparse: + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 + else: + # down_weight is chunked to each split + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 + + # up_weight is sparse: only non-zero values are copied to each split + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] + + def _convert_sd_scripts_to_ai_toolkit(sds_sd): + ait_sd = {} + for i in range(19): + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_out.0", + ) + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_0", + f"transformer.transformer_blocks.{i}.ff.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_2", + f"transformer.transformer_blocks.{i}.ff.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mod_lin", + f"transformer.transformer_blocks.{i}.norm1.linear", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_add_out", + ) + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_0", + f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_2", + f"transformer.transformer_blocks.{i}.ff_context.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mod_lin", + f"transformer.transformer_blocks.{i}.norm1_context.linear", + ) + + for i in range(38): + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + dims=[3072, 3072, 3072, 12288], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear2", + f"transformer.single_transformer_blocks.{i}.proj_out", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_modulation_lin", + f"transformer.single_transformer_blocks.{i}.norm.linear", + ) + + if len(sds_sd) > 0: + logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + + return ait_sd + + return _convert_sd_scripts_to_ai_toolkit(state_dict) + + +# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6 +# Some utilities were reused from +# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py +def _convert_xlabs_flux_lora_to_diffusers(old_state_dict): + new_state_dict = {} + orig_keys = list(old_state_dict.keys()) + + def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + down_weight = sds_sd.pop(sds_key) + up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight")) + + # calculate dims if not provided + num_splits = len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 + + for old_key in orig_keys: + # Handle double_blocks + if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")): + block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1) + new_key = f"transformer.transformer_blocks.{block_num}" + + if "processor.proj_lora1" in old_key: + new_key += ".attn.to_out.0" + elif "processor.proj_lora2" in old_key: + new_key += ".attn.to_add_out" + # Handle text latents. + elif "processor.qkv_lora2" in old_key and "up" not in old_key: + handle_qkv( + old_state_dict, + new_state_dict, + old_key, + [ + f"transformer.transformer_blocks.{block_num}.attn.add_q_proj", + f"transformer.transformer_blocks.{block_num}.attn.add_k_proj", + f"transformer.transformer_blocks.{block_num}.attn.add_v_proj", + ], + ) + # continue + # Handle image latents. + elif "processor.qkv_lora1" in old_key and "up" not in old_key: + handle_qkv( + old_state_dict, + new_state_dict, + old_key, + [ + f"transformer.transformer_blocks.{block_num}.attn.to_q", + f"transformer.transformer_blocks.{block_num}.attn.to_k", + f"transformer.transformer_blocks.{block_num}.attn.to_v", + ], + ) + # continue + + if "down" in old_key: + new_key += ".lora_A.weight" + elif "up" in old_key: + new_key += ".lora_B.weight" + + # Handle single_blocks + elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"): + block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1) + new_key = f"transformer.single_transformer_blocks.{block_num}" + + if "proj_lora1" in old_key or "proj_lora2" in old_key: + new_key += ".proj_out" + elif "qkv_lora1" in old_key or "qkv_lora2" in old_key: + new_key += ".norm.linear" + + if "down" in old_key: + new_key += ".lora_A.weight" + elif "up" in old_key: + new_key += ".lora_B.weight" + + else: + # Handle other potential key patterns here + new_key = old_key + + # Since we already handle qkv above. + if "qkv" not in old_key: + new_state_dict[new_key] = old_state_dict.pop(old_key) + + if len(old_state_dict) > 0: + raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") + + return new_state_dict diff --git a/diffusers/src/diffusers/loaders/lora_pipeline.py b/diffusers/src/diffusers/loaders/lora_pipeline.py new file mode 100755 index 0000000..51e1c0c --- /dev/null +++ b/diffusers/src/diffusers/loaders/lora_pipeline.py @@ -0,0 +1,2828 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Callable, Dict, List, Optional, Union + +import torch +from huggingface_hub.utils import validate_hf_hub_args + +from ..utils import ( + USE_PEFT_BACKEND, + convert_state_dict_to_diffusers, + convert_state_dict_to_peft, + convert_unet_state_dict_to_peft, + deprecate, + get_adapter_name, + get_peft_kwargs, + is_peft_version, + is_transformers_available, + logging, + scale_lora_layers, +) +from .lora_base import LoraBaseMixin +from .lora_conversion_utils import ( + _convert_kohya_flux_lora_to_diffusers, + _convert_non_diffusers_lora_to_diffusers, + _convert_xlabs_flux_lora_to_diffusers, + _maybe_map_sgm_blocks_to_diffusers, +) + + +if is_transformers_available(): + from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules + +logger = logging.get_logger(__name__) + +TEXT_ENCODER_NAME = "text_encoder" +UNET_NAME = "unet" +TRANSFORMER_NAME = "transformer" + +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" + + +class StableDiffusionLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + """ + + _lora_loadable_modules = ["unet", "text_encoder"] + unet_name = UNET_NAME + text_encoder_name = TEXT_ENCODER_NAME + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is + loaded into `self.unet`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state + dict is loaded into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_unet( + state_dict, + network_alphas=network_alphas, + unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, + adapter_name=adapter_name, + _pipeline=self, + ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=network_alphas, + text_encoder=getattr(self, self.text_encoder_name) + if not hasattr(self, "text_encoder") + else self.text_encoder, + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + ) + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + """ + # Load the main state dict first which has the LoRA layers for either of + # UNet and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + network_alphas = None + # TODO: replace it with a method from `state_dict_utils` + if all( + ( + k.startswith("lora_te_") + or k.startswith("lora_unet_") + or k.startswith("lora_te1_") + or k.startswith("lora_te2_") + ) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. + if unet_config is not None: + # use unet config to remap block numbers + state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) + + return state_dict, network_alphas + + @classmethod + def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `unet`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) + if not only_text_encoder: + # Load the layers corresponding to UNet. + logger.info(f"Loading {cls.unet_name}.") + unet.load_attn_procs( + state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + ) + + @classmethod + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = cls.text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if any(cls.text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `unet`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (unet_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") + + if unet_lora_layers: + state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["unet", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components) + + +class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into Stable Diffusion XL [`UNet2DConditionModel`], + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and + [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection). + """ + + _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"] + unet_name = UNET_NAME + text_encoder_name = TEXT_ENCODER_NAME + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is + loaded into `self.unet`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state + dict is loaded into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # We could have accessed the unet config from `lora_state_dict()` too. We pass + # it here explicitly to be able to tell that it's coming from an SDXL + # pipeline. + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, + ) + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_unet( + state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self + ) + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + ) + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + """ + # Load the main state dict first which has the LoRA layers for either of + # UNet and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + network_alphas = None + # TODO: replace it with a method from `state_dict_utils` + if all( + ( + k.startswith("lora_te_") + or k.startswith("lora_unet_") + or k.startswith("lora_te1_") + or k.startswith("lora_te2_") + ) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. + if unet_config is not None: + # use unet config to remap block numbers + state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) + + return state_dict, network_alphas + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet + def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `unet`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) + if not only_text_encoder: + # Load the layers corresponding to UNet. + logger.info(f"Loading {cls.unet_name}.") + unet.load_attn_procs( + state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = cls.text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if any(cls.text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `unet`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): + raise ValueError( + "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." + ) + + if unet_lora_layers: + state_dict.update(cls.pack_weights(unet_lora_layers, "unet")) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) + + if text_encoder_2_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["unet", "text_encoder", "text_encoder_2"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components) + + +class SD3LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`SD3Transformer2DModel`], + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and + [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection). + + Specific to [`StableDiffusion3Pipeline`]. + """ + + _lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"] + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + return state_dict + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + ) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=None, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=None, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + ) + + @classmethod + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SD3Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = cls.text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if any(cls.text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, torch.nn.Module] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): + raise ValueError( + "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." + ) + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) + + if text_encoder_2_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components) + + +class FluxLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`FluxTransformer2DModel`], + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + + Specific to [`StableDiffusion3Pipeline`]. + """ + + _lora_loadable_modules = ["transformer", "text_encoder"] + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + return_alphas: bool = False, + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. + is_kohya = any(".lora_down.weight" in k for k in state_dict) + if is_kohya: + state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) + # Kohya already takes care of scaling the LoRA parameters with alpha. + return (state_dict, None) if return_alphas else state_dict + + is_xlabs = any("processor" in k for k in state_dict) + if is_xlabs: + state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) + # xlabs doesn't use `alpha`. + return (state_dict, None) if return_alphas else state_dict + + is_bfl_control = any("query_norm.scale" in k for k in state_dict) + if is_bfl_control: + state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) + return (state_dict, None) if return_alphas else state_dict + + # For state dicts like + # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA + keys = list(state_dict.keys()) + network_alphas = {} + for k in keys: + if "alpha" in k: + alpha_value = state_dict.get(k) + if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance( + alpha_value, float + ): + network_alphas[k] = state_dict.pop(k) + else: + raise ValueError( + f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue." + ) + + if return_alphas: + return state_dict, network_alphas + else: + return state_dict + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs + ) + + has_lora_keys = any("lora" in key for key in state_dict.keys()) + + # Flux Control LoRAs also have norm keys + has_norm_keys = any( + norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys + ) + + if not (has_lora_keys or has_norm_keys): + raise ValueError("Invalid LoRA checkpoint.") + + transformer_lora_state_dict = { + k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k + } + transformer_norm_state_dict = { + k: state_dict.pop(k) + for k in list(state_dict.keys()) + if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) + } + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( + transformer, transformer_lora_state_dict, transformer_norm_state_dict + ) + + if has_param_with_expanded_shape: + logger.info( + "The LoRA weights contain parameters that have different shapes that expected by the transformer. " + "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " + "To get a comprehensive list of parameter names that were modified, enable debug logging." + ) + + if len(transformer_lora_state_dict) > 0: + self.load_lora_into_transformer( + transformer_lora_state_dict, + network_alphas=network_alphas, + transformer=transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + if len(transformer_norm_state_dict) > 0: + transformer._transformer_norm_layers = self._load_norm_into_transformer( + transformer_norm_state_dict, + transformer=transformer, + discard_original_layers=False, + ) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + def load_lora_into_transformer( + cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + transformer (`FluxTransformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + keys = list(state_dict.keys()) + transformer_present = any(key.startswith(cls.transformer_name) for key in keys) + if transformer_present: + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + def _load_norm_into_transformer( + cls, + state_dict, + transformer, + prefix=None, + discard_original_layers=False, + ) -> Dict[str, torch.Tensor]: + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) + + # Find invalid keys + transformer_state_dict = transformer.state_dict() + transformer_keys = set(transformer_state_dict.keys()) + state_dict_keys = set(state_dict.keys()) + extra_keys = list(state_dict_keys - transformer_keys) + + if extra_keys: + logger.warning( + f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." + ) + + for key in extra_keys: + state_dict.pop(key) + + # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected + overwritten_layers_state_dict = {} + if not discard_original_layers: + for key in state_dict.keys(): + overwritten_layers_state_dict[key] = transformer_state_dict[key].clone() + + logger.info( + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " + 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' + "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. " + "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues." + ) + + # We can't load with strict=True because the current state_dict does not contain all the transformer keys + incompatible_keys = transformer.load_state_dict(state_dict, strict=False) + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + + # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. + if unexpected_keys: + if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys): + raise ValueError( + f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer." + ) + + return overwritten_layers_state_dict + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = cls.text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if any(cls.text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + **peft_kwargs, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if ( + hasattr(transformer, "_transformer_norm_layers") + and isinstance(transformer._transformer_norm_layers, dict) + and len(transformer._transformer_norm_layers.keys()) > 0 + ): + logger.info( + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer " + "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly " + "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." + ) + + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + """ + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: + transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) + + super().unfuse_lora(components=components) + + # We override this here account for `_transformer_norm_layers`. + def unload_lora_weights(self): + super().unload_lora_weights() + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: + transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) + transformer._transformer_norm_layers = None + + @classmethod + def _maybe_expand_transformer_param_shape_or_error_( + cls, + transformer: torch.nn.Module, + lora_state_dict=None, + norm_state_dict=None, + prefix=None, + ) -> bool: + """ + Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and + generalizes things a bit so that any parameter that needs expansion receives appropriate treatement. + """ + state_dict = {} + if lora_state_dict is not None: + state_dict.update(lora_state_dict) + if norm_state_dict is not None: + state_dict.update(norm_state_dict) + + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) + + # Expand transformer parameter shapes if they don't match lora + has_param_with_shape_update = False + + for name, module in transformer.named_modules(): + if isinstance(module, torch.nn.Linear): + module_weight = module.weight.data + module_bias = module.bias.data if module.bias is not None else None + bias = module_bias is not None + + lora_A_weight_name = f"{name}.lora_A.weight" + lora_B_weight_name = f"{name}.lora_B.weight" + if lora_A_weight_name not in state_dict.keys(): + continue + + in_features = state_dict[lora_A_weight_name].shape[1] + out_features = state_dict[lora_B_weight_name].shape[0] + + # This means there's no need for an expansion in the params, so we simply skip. + if tuple(module_weight.shape) == (out_features, in_features): + continue + + module_out_features, module_in_features = module_weight.shape + if out_features < module_out_features or in_features < module_in_features: + raise NotImplementedError( + f"Only LoRAs with input/output features higher than the current module's input/output features " + f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which " + f"are lower than {module_in_features=} and {module_out_features=}. If you require support for " + f"this please open an issue at https://github.com/huggingface/diffusers/issues." + ) + + logger.debug( + f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' + f"checkpoint contains higher number of features than expected. The number of input_features will be " + f"expanded from {module_in_features} to {in_features}, and the number of output features will be " + f"expanded from {module_out_features} to {out_features}." + ) + + has_param_with_shape_update = True + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) + + # TODO: consider initializing this under meta device for optims. + expanded_module = torch.nn.Linear( + in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype + ) + # Only weights are expanded and biases are not. + new_weight = torch.zeros_like( + expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype + ) + slices = tuple(slice(0, dim) for dim in module_weight.shape) + new_weight[slices] = module_weight + expanded_module.weight.data.copy_(new_weight) + if module_bias is not None: + expanded_module.bias.data.copy_(module_bias) + + setattr(parent_module, current_module_name, expanded_module) + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(expanded_module.weight.data.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.") + + return has_param_with_shape_update + + +# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially +# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. +class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): + _lora_loadable_modules = ["transformer", "text_encoder"] + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + + @classmethod + def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)] + network_alphas = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + if len(state_dict.keys()) > 0: + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = cls.text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if any(cls.text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + transformer_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `unet`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + +class CogVideoXLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + return state_dict + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SD3Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components) + + +class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." + deprecate("LoraLoaderMixin", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/diffusers/src/diffusers/loaders/peft.py b/diffusers/src/diffusers/loaders/peft.py new file mode 100755 index 0000000..d1c6721 --- /dev/null +++ b/diffusers/src/diffusers/loaders/peft.py @@ -0,0 +1,396 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from functools import partial +from typing import Dict, List, Optional, Union + +from ..utils import ( + MIN_PEFT_VERSION, + USE_PEFT_BACKEND, + check_peft_version, + delete_adapter_layers, + is_peft_available, + set_adapter_layers, + set_weights_and_activate_adapters, +) +from .unet_loader_utils import _maybe_expand_lora_scales + + +_SET_ADAPTER_SCALE_FN_MAPPING = { + "UNet2DConditionModel": _maybe_expand_lora_scales, + "UNetMotionModel": _maybe_expand_lora_scales, + "SD3Transformer2DModel": lambda model_cls, weights: weights, + "FluxTransformer2DModel": lambda model_cls, weights: weights, + "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, +} + + +class PeftAdapterMixin: + """ + A class containing all functions for loading and using adapters weights that are supported in PEFT library. For + more details about adapters and injecting them in a base model, check out the PEFT + [documentation](https://huggingface.co/docs/peft/index). + + Install the latest version of PEFT, and use this mixin to: + + - Attach new adapters in the model. + - Attach multiple adapters and iteratively activate/deactivate them. + - Activate/deactivate all adapters from the model. + - Get a list of the active adapters. + """ + + _hf_peft_config_loaded = False + + def set_adapters( + self, + adapter_names: Union[List[str], str], + weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None, + ): + """ + Set the currently active adapters for use in the UNet. + + Args: + adapter_names (`List[str]` or `str`): + The names of the adapters to use. + adapter_weights (`Union[List[float], float]`, *optional*): + The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the + adapters. + + Example: + + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights( + "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" + ) + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5]) + ``` + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for `set_adapters()`.") + + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + + # Expand weights into a list, one entry per adapter + # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None] + if not isinstance(weights, list): + weights = [weights] * len(adapter_names) + + if len(adapter_names) != len(weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}." + ) + + # Set None values to default of 1.0 + # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0] + weights = [w if w is not None else 1.0 for w in weights] + + # e.g. [{...}, 7] -> [{expanded dict...}, 7] + scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__] + weights = scale_expansion_fn(self, weights) + + set_weights_and_activate_adapters(self, adapter_names, weights) + + def add_adapter(self, adapter_config, adapter_name: str = "default") -> None: + r""" + Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned + to the adapter to follow the convention of the PEFT library. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT + [documentation](https://huggingface.co/docs/peft). + + Args: + adapter_config (`[~peft.PeftConfig]`): + The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt + methods. + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter to add. If no name is passed, a default name is assigned to the adapter. + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not is_peft_available(): + raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.") + + from peft import PeftConfig, inject_adapter_in_model + + if not self._hf_peft_config_loaded: + self._hf_peft_config_loaded = True + elif adapter_name in self.peft_config: + raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") + + if not isinstance(adapter_config, PeftConfig): + raise ValueError( + f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead." + ) + + # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is + # handled by the `load_lora_layers` or `StableDiffusionLoraLoaderMixin`. Therefore we set it to `None` here. + adapter_config.base_model_name_or_path = None + inject_adapter_in_model(adapter_config, self, adapter_name) + self.set_adapter(adapter_name) + + def set_adapter(self, adapter_name: Union[str, List[str]]) -> None: + """ + Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + [documentation](https://huggingface.co/docs/peft). + + Args: + adapter_name (Union[str, List[str]])): + The list of adapters to set or the adapter name in the case of a single adapter. + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + if isinstance(adapter_name, str): + adapter_name = [adapter_name] + + missing = set(adapter_name) - set(self.peft_config) + if len(missing) > 0: + raise ValueError( + f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)." + f" current loaded adapters are: {list(self.peft_config.keys())}" + ) + + from peft.tuners.tuners_utils import BaseTunerLayer + + _adapters_has_been_set = False + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_name) + # Previous versions of PEFT does not support multi-adapter inference + elif not hasattr(module, "set_adapter") and len(adapter_name) != 1: + raise ValueError( + "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT." + " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`" + ) + else: + module.active_adapter = adapter_name + _adapters_has_been_set = True + + if not _adapters_has_been_set: + raise ValueError( + "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters." + ) + + def disable_adapters(self) -> None: + r""" + Disable all adapters attached to the model and fallback to inference with the base model only. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + [documentation](https://huggingface.co/docs/peft). + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=False) + else: + # support for older PEFT versions + module.disable_adapters = True + + def enable_adapters(self) -> None: + """ + Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of + adapters to enable. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + [documentation](https://huggingface.co/docs/peft). + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=True) + else: + # support for older PEFT versions + module.disable_adapters = False + + def active_adapters(self) -> List[str]: + """ + Gets the current list of active adapters of the model. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + [documentation](https://huggingface.co/docs/peft). + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not is_peft_available(): + raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.") + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + return module.active_adapter + + def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for `fuse_lora()`.") + + self.lora_scale = lora_scale + self._safe_fusing = safe_fusing + self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names)) + + def _fuse_lora_apply(self, module, adapter_names=None): + from peft.tuners.tuners_utils import BaseTunerLayer + + merge_kwargs = {"safe_merge": self._safe_fusing} + + if isinstance(module, BaseTunerLayer): + if self.lora_scale != 1.0: + module.scale_layer(self.lora_scale) + + # For BC with prevous PEFT versions, we need to check the signature + # of the `merge` method to see if it supports the `adapter_names` argument. + supported_merge_kwargs = list(inspect.signature(module.merge).parameters) + if "adapter_names" in supported_merge_kwargs: + merge_kwargs["adapter_names"] = adapter_names + elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: + raise ValueError( + "The `adapter_names` argument is not supported with your PEFT version. Please upgrade" + " to the latest version of PEFT. `pip install -U peft`" + ) + + module.merge(**merge_kwargs) + + def unfuse_lora(self): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for `unfuse_lora()`.") + self.apply(self._unfuse_lora_apply) + + def _unfuse_lora_apply(self, module): + from peft.tuners.tuners_utils import BaseTunerLayer + + if isinstance(module, BaseTunerLayer): + module.unmerge() + + def unload_lora(self): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for `unload_lora()`.") + + from ..utils import recurse_remove_peft_layers + + recurse_remove_peft_layers(self) + if hasattr(self, "peft_config"): + del self.peft_config + + def disable_lora(self): + """ + Disables the active LoRA layers of the underlying model. + + Example: + + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights( + "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" + ) + pipeline.disable_lora() + ``` + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + set_adapter_layers(self, enabled=False) + + def enable_lora(self): + """ + Enables the active LoRA layers of the underlying model. + + Example: + + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights( + "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" + ) + pipeline.enable_lora() + ``` + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + set_adapter_layers(self, enabled=True) + + def delete_adapters(self, adapter_names: Union[List[str], str]): + """ + Delete an adapter's LoRA layers from the underlying model. + + Args: + adapter_names (`Union[List[str], str]`): + The names (single string or list of strings) of the adapter to delete. + + Example: + + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights( + "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic" + ) + pipeline.delete_adapters("cinematic") + ``` + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + if isinstance(adapter_names, str): + adapter_names = [adapter_names] + + for adapter_name in adapter_names: + delete_adapter_layers(self, adapter_name) + + # Pop also the corresponding adapter from the config + if hasattr(self, "peft_config"): + self.peft_config.pop(adapter_name, None) diff --git a/diffusers/src/diffusers/loaders/single_file.py b/diffusers/src/diffusers/loaders/single_file.py new file mode 100755 index 0000000..6ede757 --- /dev/null +++ b/diffusers/src/diffusers/loaders/single_file.py @@ -0,0 +1,550 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import inspect +import os + +import torch +from huggingface_hub import snapshot_download +from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args +from packaging import version + +from ..utils import deprecate, is_transformers_available, logging +from .single_file_utils import ( + SingleFileComponentError, + _is_legacy_scheduler_kwargs, + _is_model_weights_in_cached_folder, + _legacy_load_clip_tokenizer, + _legacy_load_safety_checker, + _legacy_load_scheduler, + create_diffusers_clip_model_from_ldm, + create_diffusers_t5_model_from_checkpoint, + fetch_diffusers_config, + fetch_original_config, + is_clip_model_in_single_file, + is_t5_in_single_file, + load_single_file_checkpoint, +) + + +logger = logging.get_logger(__name__) + +# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided +SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"] + +if is_transformers_available(): + import transformers + from transformers import PreTrainedModel, PreTrainedTokenizer + + +def load_single_file_sub_model( + library_name, + class_name, + name, + checkpoint, + pipelines, + is_pipeline_module, + cached_model_config_path, + original_config=None, + local_files_only=False, + torch_dtype=None, + is_legacy_loading=False, + **kwargs, +): + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + else: + # else we just import it from the library. + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + + if is_transformers_available(): + transformers_version = version.parse(version.parse(transformers.__version__).base_version) + else: + transformers_version = "N/A" + + is_transformers_model = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedModel) + and transformers_version >= version.parse("4.20.0") + ) + is_tokenizer = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedTokenizer) + and transformers_version >= version.parse("4.20.0") + ) + + diffusers_module = importlib.import_module(__name__.split(".")[0]) + is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin) + is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) + is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin) + + if is_diffusers_single_file_model: + load_method = getattr(class_obj, "from_single_file") + + # We cannot provide two different config options to the `from_single_file` method + # Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided + if original_config: + cached_model_config_path = None + + loaded_sub_model = load_method( + pretrained_model_link_or_path_or_dict=checkpoint, + original_config=original_config, + config=cached_model_config_path, + subfolder=name, + torch_dtype=torch_dtype, + local_files_only=local_files_only, + **kwargs, + ) + + elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint): + loaded_sub_model = create_diffusers_clip_model_from_ldm( + class_obj, + checkpoint=checkpoint, + config=cached_model_config_path, + subfolder=name, + torch_dtype=torch_dtype, + local_files_only=local_files_only, + is_legacy_loading=is_legacy_loading, + ) + + elif is_transformers_model and is_t5_in_single_file(checkpoint): + loaded_sub_model = create_diffusers_t5_model_from_checkpoint( + class_obj, + checkpoint=checkpoint, + config=cached_model_config_path, + subfolder=name, + torch_dtype=torch_dtype, + local_files_only=local_files_only, + ) + + elif is_tokenizer and is_legacy_loading: + loaded_sub_model = _legacy_load_clip_tokenizer( + class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only + ) + + elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)): + loaded_sub_model = _legacy_load_scheduler( + class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs + ) + + else: + if not hasattr(class_obj, "from_pretrained"): + raise ValueError( + ( + f"The component {class_obj.__name__} cannot be loaded as it does not seem to have" + " a supported loading method." + ) + ) + + loading_kwargs = {} + loading_kwargs.update( + { + "pretrained_model_name_or_path": cached_model_config_path, + "subfolder": name, + "local_files_only": local_files_only, + } + ) + + # Schedulers and Tokenizers don't make use of torch_dtype + # Skip passing it to those objects + if issubclass(class_obj, torch.nn.Module): + loading_kwargs.update({"torch_dtype": torch_dtype}) + + if is_diffusers_model or is_transformers_model: + if not _is_model_weights_in_cached_folder(cached_model_config_path, name): + raise SingleFileComponentError( + f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint." + ) + + load_method = getattr(class_obj, "from_pretrained") + loaded_sub_model = load_method(**loading_kwargs) + + return loaded_sub_model + + +def _map_component_types_to_config_dict(component_types): + diffusers_module = importlib.import_module(__name__.split(".")[0]) + config_dict = {} + component_types.pop("self", None) + + if is_transformers_available(): + transformers_version = version.parse(version.parse(transformers.__version__).base_version) + else: + transformers_version = "N/A" + + for component_name, component_value in component_types.items(): + is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin) + is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers" + is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin) + + is_transformers_model = ( + is_transformers_available() + and issubclass(component_value[0], PreTrainedModel) + and transformers_version >= version.parse("4.20.0") + ) + is_transformers_tokenizer = ( + is_transformers_available() + and issubclass(component_value[0], PreTrainedTokenizer) + and transformers_version >= version.parse("4.20.0") + ) + + if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS: + config_dict[component_name] = ["diffusers", component_value[0].__name__] + + elif is_scheduler_enum or is_scheduler: + if is_scheduler_enum: + # Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler + # if the type hint is a KarrassDiffusionSchedulers enum + config_dict[component_name] = ["diffusers", "DDIMScheduler"] + + elif is_scheduler: + config_dict[component_name] = ["diffusers", component_value[0].__name__] + + elif ( + is_transformers_model or is_transformers_tokenizer + ) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS: + config_dict[component_name] = ["transformers", component_value[0].__name__] + + else: + config_dict[component_name] = [None, None] + + return config_dict + + +def _infer_pipeline_config_dict(pipeline_class): + parameters = inspect.signature(pipeline_class.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + component_types = pipeline_class._get_signature_types() + + # Ignore parameters that are not required for the pipeline + component_types = {k: v for k, v in component_types.items() if k in required_parameters} + config_dict = _map_component_types_to_config_dict(component_types) + + return config_dict + + +def _download_diffusers_model_config_from_hub( + pretrained_model_name_or_path, + cache_dir, + revision, + proxies, + force_download=None, + local_files_only=None, + token=None, +): + allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"] + cached_model_path = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + revision=revision, + proxies=proxies, + force_download=force_download, + local_files_only=local_files_only, + token=token, + allow_patterns=allow_patterns, + ) + + return cached_model_path + + +class FromSingleFileMixin: + """ + Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`]. + """ + + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors` + format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + original_config_file (`str`, *optional*): + The path to the original config file that was used to train the model. If not provided, the config file + will be inferred from the checkpoint file. + config (`str`, *optional*): + Can be either: + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline + component configs in Diffusers format. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. + + Examples: + + ```py + >>> from diffusers import StableDiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = StableDiffusionPipeline.from_single_file( + ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors" + ... ) + + >>> # Download pipeline from local file + >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt + >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt") + + >>> # Enable float16 and move to GPU + >>> pipeline = StableDiffusionPipeline.from_single_file( + ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt", + ... torch_dtype=torch.float16, + ... ) + >>> pipeline.to("cuda") + ``` + + """ + original_config_file = kwargs.pop("original_config_file", None) + config = kwargs.pop("config", None) + original_config = kwargs.pop("original_config", None) + + if original_config_file is not None: + deprecation_message = ( + "`original_config_file` argument is deprecated and will be removed in future versions." + "please use the `original_config` argument instead." + ) + deprecate("original_config_file", "1.0.0", deprecation_message) + original_config = original_config_file + + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + + is_legacy_loading = False + + # We shouldn't allow configuring individual models components through a Pipeline creation method + # These model kwargs should be deprecated + scaling_factor = kwargs.get("scaling_factor", None) + if scaling_factor is not None: + deprecation_message = ( + "Passing the `scaling_factor` argument to `from_single_file is deprecated " + "and will be ignored in future versions." + ) + deprecate("scaling_factor", "1.0.0", deprecation_message) + + if original_config is not None: + original_config = fetch_original_config(original_config, local_files_only=local_files_only) + + from ..pipelines.pipeline_utils import _get_pipeline_class + + pipeline_class = _get_pipeline_class(cls, config=None) + + checkpoint = load_single_file_checkpoint( + pretrained_model_link_or_path, + force_download=force_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + if config is None: + config = fetch_diffusers_config(checkpoint) + default_pretrained_model_config_name = config["pretrained_model_name_or_path"] + else: + default_pretrained_model_config_name = config + + if not os.path.isdir(default_pretrained_model_config_name): + # Provided config is a repo_id + if default_pretrained_model_config_name.count("/") > 1: + raise ValueError( + f'The provided config "{config}"' + " is neither a valid local path nor a valid repo id. Please check the parameter." + ) + try: + # Attempt to download the config files for the pipeline + cached_model_config_path = _download_diffusers_model_config_from_hub( + default_pretrained_model_config_name, + cache_dir=cache_dir, + revision=revision, + proxies=proxies, + force_download=force_download, + local_files_only=local_files_only, + token=token, + ) + config_dict = pipeline_class.load_config(cached_model_config_path) + + except LocalEntryNotFoundError: + # `local_files_only=True` but a local diffusers format model config is not available in the cache + # If `original_config` is not provided, we need override `local_files_only` to False + # to fetch the config files from the hub so that we have a way + # to configure the pipeline components. + + if original_config is None: + logger.warning( + "`local_files_only` is True but no local configs were found for this checkpoint.\n" + "Attempting to download the necessary config files for this pipeline.\n" + ) + cached_model_config_path = _download_diffusers_model_config_from_hub( + default_pretrained_model_config_name, + cache_dir=cache_dir, + revision=revision, + proxies=proxies, + force_download=force_download, + local_files_only=False, + token=token, + ) + config_dict = pipeline_class.load_config(cached_model_config_path) + + else: + # For backwards compatibility + # If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components + logger.warning( + "Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n" + "This may lead to errors if the model components are not correctly inferred. \n" + "To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n" + "e.g. `from_single_file(, config=) \n" + "or run `from_single_file` with `local_files_only=False` first to update the local cache directory with " + "the necessary config files.\n" + ) + is_legacy_loading = True + cached_model_config_path = None + + config_dict = _infer_pipeline_config_dict(pipeline_class) + config_dict["_class_name"] = pipeline_class.__name__ + + else: + # Provided config is a path to a local directory attempt to load directly. + cached_model_config_path = default_pretrained_model_config_name + config_dict = pipeline_class.load_config(cached_model_config_path) + + # pop out "_ignore_files" as it is only needed for download + config_dict.pop("_ignore_files", None) + + expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + + init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} + init_kwargs = {**init_kwargs, **passed_pipe_kwargs} + + from diffusers import pipelines + + # remove `null` components + def load_module(name, value): + if value[0] is None: + return False + if name in passed_class_obj and passed_class_obj[name] is None: + return False + if name in SINGLE_FILE_OPTIONAL_COMPONENTS: + return False + + return True + + init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + + for name, (library_name, class_name) in logging.tqdm( + sorted(init_dict.items()), desc="Loading pipeline components..." + ): + loaded_sub_model = None + is_pipeline_module = hasattr(pipelines, library_name) + + if name in passed_class_obj: + loaded_sub_model = passed_class_obj[name] + + else: + try: + loaded_sub_model = load_single_file_sub_model( + library_name=library_name, + class_name=class_name, + name=name, + checkpoint=checkpoint, + is_pipeline_module=is_pipeline_module, + cached_model_config_path=cached_model_config_path, + pipelines=pipelines, + torch_dtype=torch_dtype, + original_config=original_config, + local_files_only=local_files_only, + is_legacy_loading=is_legacy_loading, + **kwargs, + ) + except SingleFileComponentError as e: + raise SingleFileComponentError( + ( + f"{e.message}\n" + f"Please load the component before passing it in as an argument to `from_single_file`.\n" + f"\n" + f"{name} = {class_name}.from_pretrained('...')\n" + f"pipe = {pipeline_class.__name__}.from_single_file(, {name}={name})\n" + f"\n" + ) + ) + + init_kwargs[name] = loaded_sub_model + + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + passed_modules = list(passed_class_obj.keys()) + optional_modules = pipeline_class._optional_components + + if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): + for module in missing_modules: + init_kwargs[module] = passed_class_obj.get(module, None) + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." + ) + + # deprecated kwargs + load_safety_checker = kwargs.pop("load_safety_checker", None) + if load_safety_checker is not None: + deprecation_message = ( + "Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`" + "using the `safety_checker` and `feature_extractor` arguments in `from_single_file`" + ) + deprecate("load_safety_checker", "1.0.0", deprecation_message) + + safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype) + init_kwargs.update(safety_checker_components) + + pipe = pipeline_class(**init_kwargs) + + return pipe \ No newline at end of file diff --git a/diffusers/src/diffusers/loaders/single_file_model.py b/diffusers/src/diffusers/loaders/single_file_model.py new file mode 100755 index 0000000..3fe1abf --- /dev/null +++ b/diffusers/src/diffusers/loaders/single_file_model.py @@ -0,0 +1,318 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import inspect +import re +from contextlib import nullcontext +from typing import Optional + +from huggingface_hub.utils import validate_hf_hub_args + +from ..utils import deprecate, is_accelerate_available, logging +from .single_file_utils import ( + SingleFileComponentError, + convert_animatediff_checkpoint_to_diffusers, + convert_controlnet_checkpoint, + convert_flux_transformer_checkpoint_to_diffusers, + convert_ldm_unet_checkpoint, + convert_ldm_vae_checkpoint, + convert_sd3_transformer_checkpoint_to_diffusers, + convert_stable_cascade_unet_single_file_to_diffusers, + create_controlnet_diffusers_config_from_ldm, + create_unet_diffusers_config_from_ldm, + create_vae_diffusers_config_from_ldm, + fetch_diffusers_config, + fetch_original_config, + load_single_file_checkpoint, +) + + +logger = logging.get_logger(__name__) + + +if is_accelerate_available(): + from accelerate import init_empty_weights + + from ..models.modeling_utils import load_model_dict_into_meta + + +SINGLE_FILE_LOADABLE_CLASSES = { + "StableCascadeUNet": { + "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers, + }, + "UNet2DConditionModel": { + "checkpoint_mapping_fn": convert_ldm_unet_checkpoint, + "config_mapping_fn": create_unet_diffusers_config_from_ldm, + "default_subfolder": "unet", + "legacy_kwargs": { + "num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args + }, + }, + "AutoencoderKL": { + "checkpoint_mapping_fn": convert_ldm_vae_checkpoint, + "config_mapping_fn": create_vae_diffusers_config_from_ldm, + "default_subfolder": "vae", + }, + "ControlNetModel": { + "checkpoint_mapping_fn": convert_controlnet_checkpoint, + "config_mapping_fn": create_controlnet_diffusers_config_from_ldm, + }, + "SD3Transformer2DModel": { + "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, + "MotionAdapter": { + "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers, + }, + "SparseControlNetModel": { + "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers, + }, + "FluxTransformer2DModel": { + "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, +} + + +def _get_single_file_loadable_mapping_class(cls): + diffusers_module = importlib.import_module(__name__.split(".")[0]) + for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES: + loadable_class = getattr(diffusers_module, loadable_class_str) + + if issubclass(cls, loadable_class): + return loadable_class_str + + return None + + +def _get_mapping_function_kwargs(mapping_fn, **kwargs): + parameters = inspect.signature(mapping_fn).parameters + + mapping_kwargs = {} + for parameter in parameters: + if parameter in kwargs: + mapping_kwargs[parameter] = kwargs[parameter] + + return mapping_kwargs + + +class FromOriginalModelMixin: + """ + Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model. + """ + + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs): + r""" + Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model + is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path_or_dict (`str`, *optional*): + Can be either: + - A link to the `.safetensors` or `.ckpt` file (for example + `"https://huggingface.co//blob/main/.safetensors"`) on the Hub. + - A path to a local *file* containing the weights of the component model. + - A state dict containing the component model weights. + config (`str`, *optional*): + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted + on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component + configs in Diffusers format. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + original_config (`str`, *optional*): + Dict or path to a yaml file containing the configuration for the model in its original format. + If a dict is provided, it will be used to initialize the model configuration. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` + method. See example below for more information. + + ```py + >>> from diffusers import StableCascadeUNet + + >>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors" + >>> model = StableCascadeUNet.from_single_file(ckpt_path) + ``` + """ + + mapping_class_name = _get_single_file_loadable_mapping_class(cls) + # if class_name not in SINGLE_FILE_LOADABLE_CLASSES: + if mapping_class_name is None: + raise ValueError( + f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}" + ) + + pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None) + if pretrained_model_link_or_path is not None: + deprecation_message = ( + "Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes" + ) + deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message) + pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path + + config = kwargs.pop("config", None) + original_config = kwargs.pop("original_config", None) + + if config is not None and original_config is not None: + raise ValueError( + "`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments" + ) + + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + subfolder = kwargs.pop("subfolder", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + + if isinstance(pretrained_model_link_or_path_or_dict, dict): + checkpoint = pretrained_model_link_or_path_or_dict + else: + checkpoint = load_single_file_checkpoint( + pretrained_model_link_or_path_or_dict, + force_download=force_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] + + checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"] + if original_config: + if "config_mapping_fn" in mapping_functions: + config_mapping_fn = mapping_functions["config_mapping_fn"] + else: + config_mapping_fn = None + + if config_mapping_fn is None: + raise ValueError( + ( + f"`original_config` has been provided for {mapping_class_name} but no mapping function" + "was found to convert the original config to a Diffusers config in" + "`diffusers.loaders.single_file_utils`" + ) + ) + + if isinstance(original_config, str): + # If original_config is a URL or filepath fetch the original_config dict + original_config = fetch_original_config(original_config, local_files_only=local_files_only) + + config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs) + diffusers_model_config = config_mapping_fn( + original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs + ) + else: + if config: + if isinstance(config, str): + default_pretrained_model_config_name = config + else: + raise ValueError( + ( + "Invalid `config` argument. Please provide a string representing a repo id" + "or path to a local Diffusers model repo." + ) + ) + + else: + config = fetch_diffusers_config(checkpoint) + default_pretrained_model_config_name = config["pretrained_model_name_or_path"] + + if "default_subfolder" in mapping_functions: + subfolder = mapping_functions["default_subfolder"] + + subfolder = subfolder or config.pop( + "subfolder", None + ) # some configs contain a subfolder key, e.g. StableCascadeUNet + + diffusers_model_config = cls.load_config( + pretrained_model_name_or_path=default_pretrained_model_config_name, + subfolder=subfolder, + local_files_only=local_files_only, + ) + expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) + + # Map legacy kwargs to new kwargs + if "legacy_kwargs" in mapping_functions: + legacy_kwargs = mapping_functions["legacy_kwargs"] + for legacy_key, new_key in legacy_kwargs.items(): + if legacy_key in kwargs: + kwargs[new_key] = kwargs.pop(legacy_key) + + model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs} + diffusers_model_config.update(model_kwargs) + + checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs) + diffusers_format_checkpoint = checkpoint_mapping_fn( + config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs + ) + if not diffusers_format_checkpoint: + raise SingleFileComponentError( + f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint." + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model = cls.from_config(diffusers_model_config) + + if is_accelerate_available(): + unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + + else: + _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) + + if model._keys_to_ignore_on_load_unexpected is not None: + for pat in model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + if torch_dtype is not None: + model.to(torch_dtype) + + model.eval() + + return model diff --git a/diffusers/src/diffusers/loaders/single_file_utils.py b/diffusers/src/diffusers/loaders/single_file_utils.py new file mode 100755 index 0000000..d620c15 --- /dev/null +++ b/diffusers/src/diffusers/loaders/single_file_utils.py @@ -0,0 +1,2098 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the Stable Diffusion checkpoints.""" + +import os +import re +from contextlib import nullcontext +from io import BytesIO +from urllib.parse import urlparse + +import requests +import torch +import yaml + +from ..models.modeling_utils import load_state_dict +from ..schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EDMDPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ..utils import ( + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + deprecate, + is_accelerate_available, + is_transformers_available, + logging, +) +from ..utils.hub_utils import _get_model_file + + +if is_transformers_available(): + from transformers import AutoImageProcessor + +if is_accelerate_available(): + from accelerate import init_empty_weights + + from ..models.modeling_utils import load_model_dict_into_meta + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CHECKPOINT_KEY_NAMES = { + "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias", + "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias", + "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias", + "controlnet": "control_model.time_embed.0.weight", + "playground-v2-5": "edm_mean", + "inpainting": "model.diffusion_model.input_blocks.0.0.weight", + "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", + "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight", + "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight", + "open_clip": "cond_stage_model.model.token_embedding.weight", + "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding", + "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection", + "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight", + "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight", + "stable_cascade_stage_c": "clip_txt_mapper.weight", + "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias", + "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe", + "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias", + "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", + "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight", + "animatediff_rgb": "controlnet_cond_embedding.weight", + "flux": [ + "double_blocks.0.img_attn.norm.key_norm.scale", + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", + ], +} + +DIFFUSERS_DEFAULT_PIPELINE_PATHS = { + "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"}, + "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"}, + "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"}, + "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"}, + "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"}, + "inpainting": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8-inpainting"}, + "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"}, + "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"}, + "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"}, + "v1": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8"}, + "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"}, + "stable_cascade_stage_b_lite": { + "pretrained_model_name_or_path": "stabilityai/stable-cascade", + "subfolder": "decoder_lite", + }, + "stable_cascade_stage_c": { + "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", + "subfolder": "prior", + }, + "stable_cascade_stage_c_lite": { + "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", + "subfolder": "prior_lite", + }, + "sd3": { + "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers", + }, + "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"}, + "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"}, + "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"}, + "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"}, + "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"}, + "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"}, + "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"}, + "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, +} + +# Use to configure model sample size when original config is provided +DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = { + "xl_base": 1024, + "xl_refiner": 1024, + "xl_inpaint": 1024, + "playground-v2-5": 1024, + "upscale": 512, + "inpainting": 512, + "inpainting_v2": 512, + "controlnet": 512, + "v2": 768, + "v1": 512, +} + + +DIFFUSERS_TO_LDM_MAPPING = { + "unet": { + "layers": { + "time_embedding.linear_1.weight": "time_embed.0.weight", + "time_embedding.linear_1.bias": "time_embed.0.bias", + "time_embedding.linear_2.weight": "time_embed.2.weight", + "time_embedding.linear_2.bias": "time_embed.2.bias", + "conv_in.weight": "input_blocks.0.0.weight", + "conv_in.bias": "input_blocks.0.0.bias", + "conv_norm_out.weight": "out.0.weight", + "conv_norm_out.bias": "out.0.bias", + "conv_out.weight": "out.2.weight", + "conv_out.bias": "out.2.bias", + }, + "class_embed_type": { + "class_embedding.linear_1.weight": "label_emb.0.0.weight", + "class_embedding.linear_1.bias": "label_emb.0.0.bias", + "class_embedding.linear_2.weight": "label_emb.0.2.weight", + "class_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + "addition_embed_type": { + "add_embedding.linear_1.weight": "label_emb.0.0.weight", + "add_embedding.linear_1.bias": "label_emb.0.0.bias", + "add_embedding.linear_2.weight": "label_emb.0.2.weight", + "add_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + }, + "controlnet": { + "layers": { + "time_embedding.linear_1.weight": "time_embed.0.weight", + "time_embedding.linear_1.bias": "time_embed.0.bias", + "time_embedding.linear_2.weight": "time_embed.2.weight", + "time_embedding.linear_2.bias": "time_embed.2.bias", + "conv_in.weight": "input_blocks.0.0.weight", + "conv_in.bias": "input_blocks.0.0.bias", + "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight", + "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias", + "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight", + "controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias", + }, + "class_embed_type": { + "class_embedding.linear_1.weight": "label_emb.0.0.weight", + "class_embedding.linear_1.bias": "label_emb.0.0.bias", + "class_embedding.linear_2.weight": "label_emb.0.2.weight", + "class_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + "addition_embed_type": { + "add_embedding.linear_1.weight": "label_emb.0.0.weight", + "add_embedding.linear_1.bias": "label_emb.0.0.bias", + "add_embedding.linear_2.weight": "label_emb.0.2.weight", + "add_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + }, + "vae": { + "encoder.conv_in.weight": "encoder.conv_in.weight", + "encoder.conv_in.bias": "encoder.conv_in.bias", + "encoder.conv_out.weight": "encoder.conv_out.weight", + "encoder.conv_out.bias": "encoder.conv_out.bias", + "encoder.conv_norm_out.weight": "encoder.norm_out.weight", + "encoder.conv_norm_out.bias": "encoder.norm_out.bias", + "decoder.conv_in.weight": "decoder.conv_in.weight", + "decoder.conv_in.bias": "decoder.conv_in.bias", + "decoder.conv_out.weight": "decoder.conv_out.weight", + "decoder.conv_out.bias": "decoder.conv_out.bias", + "decoder.conv_norm_out.weight": "decoder.norm_out.weight", + "decoder.conv_norm_out.bias": "decoder.norm_out.bias", + "quant_conv.weight": "quant_conv.weight", + "quant_conv.bias": "quant_conv.bias", + "post_quant_conv.weight": "post_quant_conv.weight", + "post_quant_conv.bias": "post_quant_conv.bias", + }, + "openclip": { + "layers": { + "text_model.embeddings.position_embedding.weight": "positional_embedding", + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.final_layer_norm.weight": "ln_final.weight", + "text_model.final_layer_norm.bias": "ln_final.bias", + "text_projection.weight": "text_projection", + }, + "transformer": { + "text_model.encoder.layers.": "resblocks.", + "layer_norm1": "ln_1", + "layer_norm2": "ln_2", + ".fc1.": ".c_fc.", + ".fc2.": ".c_proj.", + ".self_attn": ".attn", + "transformer.text_model.final_layer_norm.": "ln_final.", + "transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "transformer.text_model.embeddings.position_embedding.weight": "positional_embedding", + }, + }, +} + +SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [ + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias", + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight", + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.23.ln_1.bias", + "cond_stage_model.model.transformer.resblocks.23.ln_1.weight", + "cond_stage_model.model.transformer.resblocks.23.ln_2.bias", + "cond_stage_model.model.transformer.resblocks.23.ln_2.weight", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight", + "cond_stage_model.model.text_projection", +] + +# To support legacy scheduler_type argument +SCHEDULER_DEFAULT_CONFIG = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", +} + +LDM_VAE_KEYS = ["first_stage_model.", "vae."] +LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215 +PLAYGROUND_VAE_SCALING_FACTOR = 0.5 +LDM_UNET_KEY = "model.diffusion_model." +LDM_CONTROLNET_KEY = "control_model." +LDM_CLIP_PREFIX_TO_REMOVE = [ + "cond_stage_model.transformer.", + "conditioner.embedders.0.transformer.", +] +LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 +SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"] + +VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] + + +class SingleFileComponentError(Exception): + def __init__(self, message=None): + self.message = message + super().__init__(self.message) + + +def is_valid_url(url): + result = urlparse(url) + if result.scheme and result.netloc: + return True + + return False + + +def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): + if not is_valid_url(pretrained_model_name_or_path): + raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.") + + pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)" + weights_name = None + repo_id = (None,) + for prefix in VALID_URL_PREFIXES: + pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "") + match = re.match(pattern, pretrained_model_name_or_path) + if not match: + logger.warning("Unable to identify the repo_id and weights_name from the provided URL.") + return repo_id, weights_name + + repo_id = f"{match.group(1)}/{match.group(2)}" + weights_name = match.group(3) + + return repo_id, weights_name + + +def _is_model_weights_in_cached_folder(cached_folder, name): + pretrained_model_name_or_path = os.path.join(cached_folder, name) + weights_exist = False + + for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]: + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + weights_exist = True + + return weights_exist + + +def _is_legacy_scheduler_kwargs(kwargs): + return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys()) + + +def load_single_file_checkpoint( + pretrained_model_link_or_path, + force_download=False, + proxies=None, + token=None, + cache_dir=None, + local_files_only=None, + revision=None, +): + if os.path.isfile(pretrained_model_link_or_path): + pretrained_model_link_or_path = pretrained_model_link_or_path + + else: + repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path) + pretrained_model_link_or_path = _get_model_file( + repo_id, + weights_name=weights_name, + force_download=force_download, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + ) + + checkpoint = load_state_dict(pretrained_model_link_or_path) + + # some checkpoints contain the model state dict under a "state_dict" key + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + return checkpoint + + +def fetch_original_config(original_config_file, local_files_only=False): + if os.path.isfile(original_config_file): + with open(original_config_file, "r") as fp: + original_config_file = fp.read() + + elif is_valid_url(original_config_file): + if local_files_only: + raise ValueError( + "`local_files_only` is set to True, but a URL was provided as `original_config_file`. " + "Please provide a valid local file path." + ) + + original_config_file = BytesIO(requests.get(original_config_file).content) + + else: + raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.") + + original_config = yaml.safe_load(original_config_file) + + return original_config + + +def is_clip_model(checkpoint): + if CHECKPOINT_KEY_NAMES["clip"] in checkpoint: + return True + + return False + + +def is_clip_sdxl_model(checkpoint): + if CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint: + return True + + return False + + +def is_clip_sd3_model(checkpoint): + if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint: + return True + + return False + + +def is_open_clip_model(checkpoint): + if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint: + return True + + return False + + +def is_open_clip_sdxl_model(checkpoint): + if CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint: + return True + + return False + + +def is_open_clip_sd3_model(checkpoint): + if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint: + return True + + return False + + +def is_open_clip_sdxl_refiner_model(checkpoint): + if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint: + return True + + return False + + +def is_clip_model_in_single_file(class_obj, checkpoint): + is_clip_in_checkpoint = any( + [ + is_clip_model(checkpoint), + is_clip_sd3_model(checkpoint), + is_open_clip_model(checkpoint), + is_open_clip_sdxl_model(checkpoint), + is_open_clip_sdxl_refiner_model(checkpoint), + is_open_clip_sd3_model(checkpoint), + ] + ) + if ( + class_obj.__name__ == "CLIPTextModel" or class_obj.__name__ == "CLIPTextModelWithProjection" + ) and is_clip_in_checkpoint: + return True + + return False + + +def infer_diffusers_model_type(checkpoint): + if ( + CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9 + ): + if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: + model_type = "inpainting_v2" + elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: + model_type = "xl_inpaint" + else: + model_type = "inpainting" + + elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: + model_type = "v2" + + elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint: + model_type = "playground-v2-5" + + elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: + model_type = "xl_base" + + elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint: + model_type = "xl_refiner" + + elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint: + model_type = "upscale" + + elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint: + model_type = "controlnet" + + elif ( + CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536 + ): + model_type = "stable_cascade_stage_c_lite" + + elif ( + CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048 + ): + model_type = "stable_cascade_stage_c" + + elif ( + CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576 + ): + model_type = "stable_cascade_stage_b_lite" + + elif ( + CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640 + ): + model_type = "stable_cascade_stage_b" + + elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint: + model_type = "sd3" + + elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint: + if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint: + model_type = "animatediff_scribble" + + elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint: + model_type = "animatediff_rgb" + + elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint: + model_type = "animatediff_v2" + + elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320: + model_type = "animatediff_sdxl_beta" + + elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24: + model_type = "animatediff_v1" + + else: + model_type = "animatediff_v3" + + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): + if any( + g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] + ): + model_type = "flux-dev" + else: + model_type = "flux-schnell" + else: + model_type = "v1" + + return model_type + + +def fetch_diffusers_config(checkpoint): + model_type = infer_diffusers_model_type(checkpoint) + model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type] + + return model_path + + +def set_image_size(checkpoint, image_size=None): + if image_size: + return image_size + + model_type = infer_diffusers_model_type(checkpoint) + image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type] + + return image_size + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config_from_ldm( + original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None +): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if image_size is not None: + deprecation_message = ( + "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + + image_size = set_image_size(checkpoint, image_size=image_size) + + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] + else: + unet_params = original_config["model"]["params"]["network_config"]["params"] + + if num_in_channels is not None: + deprecation_message = ( + "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + in_channels = num_in_channels + else: + in_channels = unet_params["in_channels"] + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params["transformer_depth"] is not None: + transformer_layers_per_block = ( + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params["context_dim"] is not None: + context_dim = ( + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] + ) + + if "num_classes" in unet_params: + if unet_params["num_classes"] == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": in_channels, + "down_block_types": down_block_types, + "block_out_channels": block_out_channels, + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if upcast_attention is not None: + deprecation_message = ( + "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + config["upcast_attention"] = upcast_attention + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params["disable_self_attentions"] + + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] + + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] = up_block_types + + return config + + +def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs): + if image_size is not None: + deprecation_message = ( + "Configuring ControlNetModel with the `image_size` argument" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + + image_size = set_image_size(checkpoint, image_size=image_size) + + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] + diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, image_size=image_size) + + controlnet_config = { + "conditioning_channels": unet_params["hint_channels"], + "in_channels": diffusers_unet_config["in_channels"], + "down_block_types": diffusers_unet_config["down_block_types"], + "block_out_channels": diffusers_unet_config["block_out_channels"], + "layers_per_block": diffusers_unet_config["layers_per_block"], + "cross_attention_dim": diffusers_unet_config["cross_attention_dim"], + "attention_head_dim": diffusers_unet_config["attention_head_dim"], + "use_linear_projection": diffusers_unet_config["use_linear_projection"], + "class_embed_type": diffusers_unet_config["class_embed_type"], + "addition_embed_type": diffusers_unet_config["addition_embed_type"], + "addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"], + "projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"], + "transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"], + } + + return controlnet_config + + +def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if image_size is not None: + deprecation_message = ( + "Configuring AutoencoderKL with the `image_size` argument" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + + image_size = set_image_size(checkpoint, image_size=image_size) + + if "edm_mean" in checkpoint and "edm_std" in checkpoint: + latents_mean = checkpoint["edm_mean"] + latents_std = checkpoint["edm_std"] + else: + latents_mean = None + latents_std = None + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None): + scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR + + elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]): + scaling_factor = original_config["model"]["params"]["scale_factor"] + + elif scaling_factor is None: + scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": down_block_types, + "up_block_types": up_block_types, + "block_out_channels": block_out_channels, + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + "scaling_factor": scaling_factor, + } + if latents_mean is not None and latents_std is not None: + config.update({"latents_mean": latents_mean, "latents_std": latents_std}) + + return config + + +def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None): + for ldm_key in ldm_keys: + diffusers_key = ( + ldm_key.replace("in_layers.0", "norm1") + .replace("in_layers.2", "conv1") + .replace("out_layers.0", "norm2") + .replace("out_layers.3", "conv2") + .replace("emb_layers.1", "time_emb_proj") + .replace("skip_connection", "conv_shortcut") + ) + if mapping: + diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"]) + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + +def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping): + for ldm_key in ldm_keys: + diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]) + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + +def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut") + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + +def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ( + ldm_key.replace(mapping["old"], mapping["new"]) + .replace("norm.weight", "group_norm.weight") + .replace("norm.bias", "group_norm.bias") + .replace("q.weight", "to_q.weight") + .replace("q.bias", "to_q.bias") + .replace("k.weight", "to_k.weight") + .replace("k.bias", "to_k.bias") + .replace("v.weight", "to_v.weight") + .replace("v.bias", "to_v.bias") + .replace("proj_out.weight", "to_out.0.weight") + .replace("proj_out.bias", "to_out.0.bias") + ) + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + # proj_attn.weight has to be converted from conv 1D to linear + shape = new_checkpoint[diffusers_key].shape + + if len(shape) == 3: + new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0] + elif len(shape) == 4: + new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0] + + +def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs): + is_stage_c = "clip_txt_mapper.weight" in checkpoint + + if is_stage_c: + state_dict = {} + for key in checkpoint.keys(): + if key.endswith("in_proj_weight"): + weights = checkpoint[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = checkpoint[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = checkpoint[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = checkpoint[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + else: + state_dict[key] = checkpoint[key] + else: + state_dict = {} + for key in checkpoint.keys(): + if key.endswith("in_proj_weight"): + weights = checkpoint[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = checkpoint[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = checkpoint[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = checkpoint[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + # rename clip_mapper to clip_txt_pooled_mapper + elif key.endswith("clip_mapper.weight"): + weights = checkpoint[key] + state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights + elif key.endswith("clip_mapper.bias"): + weights = checkpoint[key] + state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights + else: + state_dict[key] = checkpoint[key] + + return state_dict + + +def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + unet_key = LDM_UNET_KEY + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning("Checkpoint has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"] + for diffusers_key, ldm_key in ldm_unet_keys.items(): + if ldm_key not in unet_state_dict: + continue + new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] + + if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]): + class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"] + for diffusers_key, ldm_key in class_embed_keys.items(): + new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] + + if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"): + addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"] + for diffusers_key, ldm_key in addition_embed_keys.items(): + new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] + + # Relevant to StableDiffusionUpscalePipeline + if "num_class_embeds" in config: + if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): + new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + # Down blocks + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + unet_state_dict, + {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get( + f"input_blocks.{i}.0.op.bias" + ) + + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + unet_state_dict, + {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + # Mid blocks + for key in middle_blocks.keys(): + diffusers_key = max(key - 1, 0) + if key % 2 == 0: + update_unet_resnet_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + unet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"}, + ) + else: + update_unet_attention_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + unet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, + ) + + # Up Blocks + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + + resnets = [ + key for key in output_blocks[i] if f"output_blocks.{i}.0" in key and f"output_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + unet_state_dict, + {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + attentions = [ + key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and f"output_blocks.{i}.1.conv" not in key + ] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + unet_state_dict, + {"old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + if f"output_blocks.{i}.1.conv.weight" in unet_state_dict: + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.1.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.1.conv.bias" + ] + if f"output_blocks.{i}.2.conv.weight" in unet_state_dict: + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.2.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.2.conv.bias" + ] + + return new_checkpoint + + +def convert_controlnet_checkpoint( + checkpoint, + config, + **kwargs, +): + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + controlnet_state_dict = checkpoint + + else: + controlnet_state_dict = {} + keys = list(checkpoint.keys()) + controlnet_key = LDM_CONTROLNET_KEY + for key in keys: + if key.startswith(controlnet_key): + controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"] + for diffusers_key, ldm_key in ldm_controlnet_keys.items(): + if ldm_key not in controlnet_state_dict: + continue + new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key] + + # Retrieves the keys for the input blocks only + num_input_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer} + ) + input_blocks = { + layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Down blocks + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + controlnet_state_dict, + {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get( + f"input_blocks.{i}.0.op.bias" + ) + + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + controlnet_state_dict, + {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + # controlnet down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias") + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer} + ) + middle_blocks = { + layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Mid blocks + for key in middle_blocks.keys(): + diffusers_key = max(key - 1, 0) + if key % 2 == 0: + update_unet_resnet_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + controlnet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"}, + ) + else: + update_unet_attention_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + controlnet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, + ) + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias") + + # controlnet cond embedding blocks + cond_embedding_blocks = { + ".".join(layer.split(".")[:2]) + for layer in controlnet_state_dict + if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer) + } + num_cond_embedding_blocks = len(cond_embedding_blocks) + + for idx in range(1, num_cond_embedding_blocks + 1): + diffusers_idx = idx - 1 + cond_block_id = 2 * idx + + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get( + f"input_hint_block.{cond_block_id}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get( + f"input_hint_block.{cond_block_id}.bias" + ) + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "" + for ldm_vae_key in LDM_VAE_KEYS: + if any(k.startswith(ldm_vae_key) for k in keys): + vae_key = ldm_vae_key + + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + vae_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["vae"] + for diffusers_key, ldm_key in vae_diffusers_ldm_map.items(): + if ldm_key not in vae_state_dict: + continue + new_checkpoint[diffusers_key] = vae_state_dict[ldm_key] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len(config["down_block_types"]) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}, + ) + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get( + f"encoder.down.{i}.downsample.conv.bias" + ) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}, + ) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + update_vae_attentions_ldm_to_diffusers( + mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"} + ) + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len(config["up_block_types"]) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}, + ) + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}, + ) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + update_vae_attentions_ldm_to_diffusers( + mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"} + ) + conv_attn_to_linear(new_checkpoint) + + return new_checkpoint + + +def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None): + keys = list(checkpoint.keys()) + text_model_dict = {} + + remove_prefixes = [] + remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE) + if remove_prefix: + remove_prefixes.append(remove_prefix) + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + diffusers_key = key.replace(prefix, "") + text_model_dict[diffusers_key] = checkpoint.get(key) + + return text_model_dict + + +def convert_open_clip_checkpoint( + text_model, + checkpoint, + prefix="cond_stage_model.model.", +): + text_model_dict = {} + text_proj_key = prefix + "text_projection" + + if text_proj_key in checkpoint: + text_proj_dim = int(checkpoint[text_proj_key].shape[0]) + elif hasattr(text_model.config, "projection_dim"): + text_proj_dim = text_model.config.projection_dim + else: + text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM + + keys = list(checkpoint.keys()) + keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE + + openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"] + for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items(): + ldm_key = prefix + ldm_key + if ldm_key not in checkpoint: + continue + if ldm_key in keys_to_ignore: + continue + if ldm_key.endswith("text_projection"): + text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous() + else: + text_model_dict[diffusers_key] = checkpoint[ldm_key] + + for key in keys: + if key in keys_to_ignore: + continue + + if not key.startswith(prefix + "transformer."): + continue + + diffusers_key = key.replace(prefix + "transformer.", "") + transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"] + for new_key, old_key in transformer_diffusers_to_ldm_map.items(): + diffusers_key = ( + diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "") + ) + + if key.endswith(".in_proj_weight"): + weight_value = checkpoint.get(key) + + text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach() + text_model_dict[diffusers_key + ".k_proj.weight"] = ( + weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach() + ) + text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach() + + elif key.endswith(".in_proj_bias"): + weight_value = checkpoint.get(key) + text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach() + text_model_dict[diffusers_key + ".k_proj.bias"] = ( + weight_value[text_proj_dim : text_proj_dim * 2].clone().detach() + ) + text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach() + else: + text_model_dict[diffusers_key] = checkpoint.get(key) + + return text_model_dict + + +def create_diffusers_clip_model_from_ldm( + cls, + checkpoint, + subfolder="", + config=None, + torch_dtype=None, + local_files_only=None, + is_legacy_loading=False, +): + if config: + config = {"pretrained_model_name_or_path": config} + else: + config = fetch_diffusers_config(checkpoint) + + # For backwards compatibility + # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo + # in the cache_dir, rather than in a subfolder of the Diffusers model + if is_legacy_loading: + logger.warning( + ( + "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update " + "the local cache directory with the necessary CLIP model config files. " + "Attempting to load CLIP model from legacy cache directory." + ) + ) + + if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint): + clip_config = "openai/clip-vit-large-patch14" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + elif is_open_clip_model(checkpoint): + clip_config = "stabilityai/stable-diffusion-2" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "text_encoder" + + else: + clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model = cls(model_config) + + position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1] + + if is_clip_model(checkpoint): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) + + elif ( + is_clip_sdxl_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) + + elif ( + is_clip_sd3_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.") + diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim) + + elif is_open_clip_model(checkpoint): + prefix = "cond_stage_model.model." + diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) + + elif ( + is_open_clip_sdxl_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim + ): + prefix = "conditioner.embedders.1.model." + diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) + + elif is_open_clip_sdxl_refiner_model(checkpoint): + prefix = "conditioner.embedders.0.model." + diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) + + elif ( + is_open_clip_sd3_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.") + + else: + raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") + + if is_accelerate_available(): + unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + else: + _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) + + if model._keys_to_ignore_on_load_unexpected is not None: + for pat in model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + if torch_dtype is not None: + model.to(torch_dtype) + + model.eval() + + return model + + +def _legacy_load_scheduler( + cls, + checkpoint, + component_name, + original_config=None, + **kwargs, +): + scheduler_type = kwargs.get("scheduler_type", None) + prediction_type = kwargs.get("prediction_type", None) + + if scheduler_type is not None: + deprecation_message = ( + "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n" + "Example:\n\n" + "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n" + "scheduler = DDIMScheduler()\n" + "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n" + ) + deprecate("scheduler_type", "1.0.0", deprecation_message) + + if prediction_type is not None: + deprecation_message = ( + "Please configure an instance of a Scheduler with the appropriate `prediction_type` and " + "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n" + "Example:\n\n" + "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n" + 'scheduler = DDIMScheduler(prediction_type="v_prediction")\n' + "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n" + ) + deprecate("prediction_type", "1.0.0", deprecation_message) + + scheduler_config = SCHEDULER_DEFAULT_CONFIG + model_type = infer_diffusers_model_type(checkpoint=checkpoint) + + global_step = checkpoint["global_step"] if "global_step" in checkpoint else None + + if original_config: + num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", 1000) + else: + num_train_timesteps = 1000 + + scheduler_config["num_train_timesteps"] = num_train_timesteps + + if model_type == "v2": + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + + else: + prediction_type = prediction_type or "epsilon" + + scheduler_config["prediction_type"] = prediction_type + + if model_type in ["xl_base", "xl_refiner"]: + scheduler_type = "euler" + elif model_type == "playground": + scheduler_type = "edm_dpm_solver_multistep" + else: + if original_config: + beta_start = original_config["model"]["params"].get("linear_start") + beta_end = original_config["model"]["params"].get("linear_end") + + else: + beta_start = 0.02 + beta_end = 0.085 + + scheduler_config["beta_start"] = beta_start + scheduler_config["beta_end"] = beta_end + scheduler_config["beta_schedule"] = "scaled_linear" + scheduler_config["clip_sample"] = False + scheduler_config["set_alpha_to_one"] = False + + # to deal with an edge case StableDiffusionUpscale pipeline has two schedulers + if component_name == "low_res_scheduler": + return cls.from_config( + { + "beta_end": 0.02, + "beta_schedule": "scaled_linear", + "beta_start": 0.0001, + "clip_sample": True, + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "trained_betas": None, + "variance_type": "fixed_small", + } + ) + + if scheduler_type is None: + return cls.from_config(scheduler_config) + + elif scheduler_type == "pndm": + scheduler_config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(scheduler_config) + + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config) + + elif scheduler_type == "ddim": + scheduler = DDIMScheduler.from_config(scheduler_config) + + elif scheduler_type == "edm_dpm_solver_multistep": + scheduler_config = { + "algorithm_type": "dpmsolver++", + "dynamic_thresholding_ratio": 0.995, + "euler_at_final": False, + "final_sigmas_type": "zero", + "lower_order_final": True, + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "rho": 7.0, + "sample_max_value": 1.0, + "sigma_data": 0.5, + "sigma_max": 80.0, + "sigma_min": 0.002, + "solver_order": 2, + "solver_type": "midpoint", + "thresholding": False, + } + scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config) + + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + return scheduler + + +def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False): + if config: + config = {"pretrained_model_name_or_path": config} + else: + config = fetch_diffusers_config(checkpoint) + + if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint): + clip_config = "openai/clip-vit-large-patch14" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + elif is_open_clip_model(checkpoint): + clip_config = "stabilityai/stable-diffusion-2" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "tokenizer" + + else: + clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) + + return tokenizer + + +def _legacy_load_safety_checker(local_files_only, torch_dtype): + # Support for loading safety checker components using the deprecated + # `load_safety_checker` argument. + + from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + feature_extractor = AutoImageProcessor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype + ) + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype + ) + + return {"safety_checker": safety_checker, "feature_extractor": feature_extractor} + + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401 + caption_projection_dim = 1536 + + # Positional and patch embeddings. + converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed") + converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Timestep embeddings. + converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + + # Context projections. + converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight") + converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias") + + # Pooled context projection. + converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight") + converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias") + converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight") + converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias") + + # Transformer blocks 🎸. + for i in range(num_layers): + # Q, K, V + sample_q, sample_k, sample_v = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0 + ) + context_q, context_k, context_v = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0 + ) + + converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias]) + + converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + + # output projections. + converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.proj.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.proj.bias" + ) + + # norms. + converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias" + ) + else: + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift( + checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"), + dim=caption_projection_dim, + ) + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift( + checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"), + dim=caption_projection_dim, + ) + + # ffs. + converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc2.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc2.bias" + ) + + # Final blocks. + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim + ) + + return converted_state_dict + + +def is_t5_in_single_file(checkpoint): + if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint: + return True + + return False + + +def convert_sd3_t5_checkpoint_to_diffusers(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + + remove_prefixes = ["text_encoders.t5xxl.transformer."] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + diffusers_key = key.replace(prefix, "") + text_model_dict[diffusers_key] = checkpoint.get(key) + + return text_model_dict + + +def create_diffusers_t5_model_from_checkpoint( + cls, + checkpoint, + subfolder="", + config=None, + torch_dtype=None, + local_files_only=None, +): + if config: + config = {"pretrained_model_name_or_path": config} + else: + config = fetch_diffusers_config(checkpoint) + + model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model = cls(model_config) + + diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint) + + if is_accelerate_available(): + unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + if model._keys_to_ignore_on_load_unexpected is not None: + for pat in model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: + model.load_state_dict(diffusers_format_checkpoint) + + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = model._keep_in_fp32_modules + else: + keep_in_fp32_modules = [] + + if keep_in_fp32_modules is not None: + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): + # param = param.to(torch.float32) does not work here as only in the local scope. + param.data = param.data.to(torch.float32) + + return model + + +def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + for k, v in checkpoint.items(): + if "pos_encoder" in k: + continue + + else: + converted_state_dict[ + k.replace(".norms.0", ".norm1") + .replace(".norms.1", ".norm2") + .replace(".ff_norm", ".norm3") + .replace(".attention_blocks.0", ".attn1") + .replace(".attention_blocks.1", ".attn2") + .replace(".temporal_transformer", "") + ] = v + + return converted_state_dict + + +def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 + num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401 + mlp_ratio = 4.0 + inner_dim = 3072 + + # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; + # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + ## time_text_embed.timestep_embedder <- time_in + converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "time_in.in_layer.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias") + converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "time_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias") + + ## time_text_embed.text_embedder <- vector_in + converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight") + converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias") + converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop( + "vector_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias") + + # guidance + has_guidance = any("guidance" in k for k in checkpoint) + if has_guidance: + converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop( + "guidance_in.in_layer.weight" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop( + "guidance_in.in_layer.bias" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop( + "guidance_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop( + "guidance_in.out_layer.bias" + ) + + # context_embedder + converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight") + converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias") + + # x_embedder + converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight") + converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias") + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + # norms. + ## norm1 + converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop( + f"double_blocks.{i}.img_mod.lin.bias" + ) + ## norm1_context + converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mod.lin.bias" + ) + # Q, K, V + sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0) + context_q, context_k, context_v = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk_norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.norm.key_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.norm.key_norm.scale" + ) + # ff img_mlp + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias") + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight") + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias") + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.0.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.2.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.2.bias" + ) + # output projections. + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.proj.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.proj.bias" + ) + + # single transfomer blocks + for i in range(num_single_layers): + block_prefix = f"single_transformer_blocks.{i}." + # norm.linear <- single_blocks.0.modulation.lin + converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop( + f"single_blocks.{i}.modulation.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop( + f"single_blocks.{i}.modulation.lin.bias" + ) + # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) + q_bias, k_bias, v_bias, mlp_bias = torch.split( + checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) + converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( + f"single_blocks.{i}.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( + f"single_blocks.{i}.norm.key_norm.scale" + ) + # output projections. + converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") + converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") + + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.weight") + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.bias") + ) + + return converted_state_dict diff --git a/diffusers/src/diffusers/loaders/textual_inversion.py b/diffusers/src/diffusers/loaders/textual_inversion.py new file mode 100755 index 0000000..f34ea23 --- /dev/null +++ b/diffusers/src/diffusers/loaders/textual_inversion.py @@ -0,0 +1,580 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Optional, Union + +import safetensors +import torch +from huggingface_hub.utils import validate_hf_hub_args +from torch import nn + +from ..models.modeling_utils import load_state_dict +from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging + + +if is_transformers_available(): + from transformers import PreTrainedModel, PreTrainedTokenizer + +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + +logger = logging.get_logger(__name__) + +TEXT_INVERSION_NAME = "learned_embeds.bin" +TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" + + +@validate_hf_hub_args +def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs): + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "text_inversion", + "framework": "pytorch", + } + state_dicts = [] + for pretrained_model_name_or_path in pretrained_model_name_or_paths: + if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)): + # 3.1. Load textual inversion file + model_file = None + + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except Exception as e: + if not allow_pickle: + raise e + + model_file = None + + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=weight_name or TEXT_INVERSION_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path + + state_dicts.append(state_dict) + + return state_dicts + + +class TextualInversionLoaderMixin: + r""" + Load Textual Inversion tokens and embeddings to the tokenizer and text encoder. + """ + + def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821 + r""" + Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to + be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual + inversion token or if the textual inversion token is a single vector, the input prompt is returned. + + Parameters: + prompt (`str` or list of `str`): + The prompt or prompts to guide the image generation. + tokenizer (`PreTrainedTokenizer`): + The tokenizer responsible for encoding the prompt into input tokens. + + Returns: + `str` or list of `str`: The converted prompt + """ + if not isinstance(prompt, List): + prompts = [prompt] + else: + prompts = prompt + + prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts] + + if not isinstance(prompt, List): + return prompts[0] + + return prompts + + def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821 + r""" + Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds + to a multi-vector textual inversion embedding, this function will process the prompt so that the special token + is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual + inversion token or a textual inversion token that is a single vector, the input prompt is simply returned. + + Parameters: + prompt (`str`): + The prompt to guide the image generation. + tokenizer (`PreTrainedTokenizer`): + The tokenizer responsible for encoding the prompt into input tokens. + + Returns: + `str`: The converted prompt + """ + tokens = tokenizer.tokenize(prompt) + unique_tokens = set(tokens) + for token in unique_tokens: + if token in tokenizer.added_tokens_encoder: + replacement = token + i = 1 + while f"{token}_{i}" in tokenizer.added_tokens_encoder: + replacement += f" {token}_{i}" + i += 1 + + prompt = prompt.replace(token, replacement) + + return prompt + + def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens): + if tokenizer is None: + raise ValueError( + f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling" + f" `{self.load_textual_inversion.__name__}`" + ) + + if text_encoder is None: + raise ValueError( + f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling" + f" `{self.load_textual_inversion.__name__}`" + ) + + if len(pretrained_model_name_or_paths) > 1 and len(pretrained_model_name_or_paths) != len(tokens): + raise ValueError( + f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} " + f"Make sure both lists have the same length." + ) + + valid_tokens = [t for t in tokens if t is not None] + if len(set(valid_tokens)) < len(valid_tokens): + raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}") + + @staticmethod + def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer): + all_tokens = [] + all_embeddings = [] + for state_dict, token in zip(state_dicts, tokens): + if isinstance(state_dict, torch.Tensor): + if token is None: + raise ValueError( + "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`." + ) + loaded_token = token + embedding = state_dict + elif len(state_dict) == 1: + # diffusers + loaded_token, embedding = next(iter(state_dict.items())) + elif "string_to_param" in state_dict: + # A1111 + loaded_token = state_dict["name"] + embedding = state_dict["string_to_param"]["*"] + else: + raise ValueError( + f"Loaded state dictionary is incorrect: {state_dict}. \n\n" + "Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`" + " input key." + ) + + if token is not None and loaded_token != token: + logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") + else: + token = loaded_token + + if token in tokenizer.get_vocab(): + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder." + ) + + all_tokens.append(token) + all_embeddings.append(embedding) + + return all_tokens, all_embeddings + + @staticmethod + def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer): + all_tokens = [] + all_embeddings = [] + + for embedding, token in zip(embeddings, tokens): + if f"{token}_1" in tokenizer.get_vocab(): + multi_vector_tokens = [token] + i = 1 + while f"{token}_{i}" in tokenizer.added_tokens_encoder: + multi_vector_tokens.append(f"{token}_{i}") + i += 1 + + raise ValueError( + f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." + ) + + is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1 + if is_multi_vector: + all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])] + all_embeddings += [e for e in embedding] # noqa: C416 + else: + all_tokens += [token] + all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding] + + return all_tokens, all_embeddings + + @validate_hf_hub_args + def load_textual_inversion( + self, + pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]], + token: Optional[Union[str, List[str]]] = None, + tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821 + text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821 + **kwargs, + ): + r""" + Load Textual Inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and + Automatic1111 formats are supported). + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`): + Can be either one of the following or a list of them: + + - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a + pretrained model hosted on the Hub. + - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual + inversion weights. + - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + token (`str` or `List[str]`, *optional*): + Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a + list, then `token` must also be a list of equal length. + text_encoder ([`~transformers.CLIPTextModel`], *optional*): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + If not specified, function will take self.tokenizer. + tokenizer ([`~transformers.CLIPTokenizer`], *optional*): + A `CLIPTokenizer` to tokenize text. If not specified, function will take self.tokenizer. + weight_name (`str`, *optional*): + Name of a custom weight file. This should be used when: + + - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight + name such as `text_inv.bin`. + - The saved textual inversion file is in the Automatic1111 format. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + + Example: + + To load a Textual Inversion embedding vector in 🤗 Diffusers format: + + ```py + from diffusers import StableDiffusionPipeline + import torch + + model_id = "runwayml/stable-diffusion-v1-5" + pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + + pipe.load_textual_inversion("sd-concepts-library/cat-toy") + + prompt = "A backpack" + + image = pipe(prompt, num_inference_steps=50).images[0] + image.save("cat-backpack.png") + ``` + + To load a Textual Inversion embedding vector in Automatic1111 format, make sure to download the vector first + (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector + locally: + + ```py + from diffusers import StableDiffusionPipeline + import torch + + model_id = "runwayml/stable-diffusion-v1-5" + pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + + pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2") + + prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details." + + image = pipe(prompt, num_inference_steps=50).images[0] + image.save("character.png") + ``` + + """ + # 1. Set correct tokenizer and text encoder + tokenizer = tokenizer or getattr(self, "tokenizer", None) + text_encoder = text_encoder or getattr(self, "text_encoder", None) + + # 2. Normalize inputs + pretrained_model_name_or_paths = ( + [pretrained_model_name_or_path] + if not isinstance(pretrained_model_name_or_path, list) + else pretrained_model_name_or_path + ) + tokens = [token] if not isinstance(token, list) else token + if tokens[0] is None: + tokens = tokens * len(pretrained_model_name_or_paths) + + # 3. Check inputs + self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens) + + # 4. Load state dicts of textual embeddings + state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs) + + # 4.1 Handle the special case when state_dict is a tensor that contains n embeddings for n tokens + if len(tokens) > 1 and len(state_dicts) == 1: + if isinstance(state_dicts[0], torch.Tensor): + state_dicts = list(state_dicts[0]) + if len(tokens) != len(state_dicts): + raise ValueError( + f"You have passed a state_dict contains {len(state_dicts)} embeddings, and list of tokens of length {len(tokens)} " + f"Make sure both have the same length." + ) + + # 4. Retrieve tokens and embeddings + tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer) + + # 5. Extend tokens and embeddings for multi vector + tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer) + + # 6. Make sure all embeddings have the correct size + expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1] + if any(expected_emb_dim != emb.shape[-1] for emb in embeddings): + raise ValueError( + "Loaded embeddings are of incorrect shape. Expected each textual inversion embedding " + "to be of shape {input_embeddings.shape[-1]}, but are {embeddings.shape[-1]} " + ) + + # 7. Now we can be sure that loading the embedding matrix works + # < Unsafe code: + + # 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again + is_model_cpu_offload = False + is_sequential_cpu_offload = False + if self.hf_device_map is None: + for _, component in self.components.items(): + if isinstance(component, nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = ( + isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) + logger.info( + "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + + # 7.2 save expected device and dtype + device = text_encoder.device + dtype = text_encoder.dtype + + # 7.3 Increase token embedding matrix + text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens)) + input_embeddings = text_encoder.get_input_embeddings().weight + + # 7.4 Load token and embedding + for token, embedding in zip(tokens, embeddings): + # add tokens and get ids + tokenizer.add_tokens(token) + token_id = tokenizer.convert_tokens_to_ids(token) + input_embeddings.data[token_id] = embedding + logger.info(f"Loaded textual inversion embedding for {token}.") + + input_embeddings.to(dtype=dtype, device=device) + + # 7.5 Offload the model again + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + + # / Unsafe Code > + + def unload_textual_inversion( + self, + tokens: Optional[Union[str, List[str]]] = None, + tokenizer: Optional["PreTrainedTokenizer"] = None, + text_encoder: Optional["PreTrainedModel"] = None, + ): + r""" + Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`] + + Example: + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5") + + # Example 1 + pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork") + pipeline.load_textual_inversion("sd-concepts-library/moeb-style") + + # Remove all token embeddings + pipeline.unload_textual_inversion() + + # Example 2 + pipeline.load_textual_inversion("sd-concepts-library/moeb-style") + pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork") + + # Remove just one token + pipeline.unload_textual_inversion("") + + # Example 3: unload from SDXL + pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + embedding_path = hf_hub_download( + repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model" + ) + + # load embeddings to the text encoders + state_dict = load_file(embedding_path) + + # load embeddings of text_encoder 1 (CLIP ViT-L/14) + pipeline.load_textual_inversion( + state_dict["clip_l"], + tokens=["", ""], + text_encoder=pipeline.text_encoder, + tokenizer=pipeline.tokenizer, + ) + # load embeddings of text_encoder 2 (CLIP ViT-G/14) + pipeline.load_textual_inversion( + state_dict["clip_g"], + tokens=["", ""], + text_encoder=pipeline.text_encoder_2, + tokenizer=pipeline.tokenizer_2, + ) + + # Unload explicitly from both text encoders and tokenizers + pipeline.unload_textual_inversion( + tokens=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer + ) + pipeline.unload_textual_inversion( + tokens=["", ""], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2 + ) + ``` + """ + + tokenizer = tokenizer or getattr(self, "tokenizer", None) + text_encoder = text_encoder or getattr(self, "text_encoder", None) + + # Get textual inversion tokens and ids + token_ids = [] + last_special_token_id = None + + if tokens: + if isinstance(tokens, str): + tokens = [tokens] + for added_token_id, added_token in tokenizer.added_tokens_decoder.items(): + if not added_token.special: + if added_token.content in tokens: + token_ids.append(added_token_id) + else: + last_special_token_id = added_token_id + if len(token_ids) == 0: + raise ValueError("No tokens to remove found") + else: + tokens = [] + for added_token_id, added_token in tokenizer.added_tokens_decoder.items(): + if not added_token.special: + token_ids.append(added_token_id) + tokens.append(added_token.content) + else: + last_special_token_id = added_token_id + + # Delete from tokenizer + for token_id, token_to_remove in zip(token_ids, tokens): + del tokenizer._added_tokens_decoder[token_id] + del tokenizer._added_tokens_encoder[token_to_remove] + + # Make all token ids sequential in tokenizer + key_id = 1 + for token_id in tokenizer.added_tokens_decoder: + if token_id > last_special_token_id and token_id > last_special_token_id + key_id: + token = tokenizer._added_tokens_decoder[token_id] + tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token + del tokenizer._added_tokens_decoder[token_id] + tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id + key_id += 1 + tokenizer._update_trie() + # set correct total vocab size after removing tokens + tokenizer._update_total_vocab_size() + + # Delete from text encoder + text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim + temp_text_embedding_weights = text_encoder.get_input_embeddings().weight + text_embedding_weights = temp_text_embedding_weights[: last_special_token_id + 1] + to_append = [] + for i in range(last_special_token_id + 1, temp_text_embedding_weights.shape[0]): + if i not in token_ids: + to_append.append(temp_text_embedding_weights[i].unsqueeze(0)) + if len(to_append) > 0: + to_append = torch.cat(to_append, dim=0) + text_embedding_weights = torch.cat([text_embedding_weights, to_append], dim=0) + text_embeddings_filtered = nn.Embedding(text_embedding_weights.shape[0], text_embedding_dim) + text_embeddings_filtered.weight.data = text_embedding_weights + text_encoder.set_input_embeddings(text_embeddings_filtered) \ No newline at end of file diff --git a/diffusers/src/diffusers/loaders/unet.py b/diffusers/src/diffusers/loaders/unet.py new file mode 100755 index 0000000..32ace77 --- /dev/null +++ b/diffusers/src/diffusers/loaders/unet.py @@ -0,0 +1,906 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from collections import defaultdict +from contextlib import nullcontext +from pathlib import Path +from typing import Callable, Dict, Union + +import safetensors +import torch +import torch.nn.functional as F +from huggingface_hub.utils import validate_hf_hub_args +from torch import nn + +from ..models.embeddings import ( + ImageProjection, + IPAdapterFaceIDImageProjection, + IPAdapterFaceIDPlusImageProjection, + IPAdapterFullImageProjection, + IPAdapterPlusImageProjection, + MultiIPAdapterImageProjection, +) +from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict +from ..utils import ( + USE_PEFT_BACKEND, + _get_model_file, + convert_unet_state_dict_to_peft, + get_adapter_name, + get_peft_kwargs, + is_accelerate_available, + is_peft_version, + is_torch_version, + logging, +) +from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME +from .utils import AttnProcsLayers + + +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + +logger = logging.get_logger(__name__) + + +CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin" +CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" + + +class UNet2DConditionLoadersMixin: + """ + Load LoRA layers into a [`UNet2DCondtionModel`]. + """ + + text_encoder_name = TEXT_ENCODER_NAME + unet_name = UNET_NAME + + @validate_hf_hub_args + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + r""" + Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be + defined in + [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py) + and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install + `peft`: `pip install -U peft`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a directory (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + adapter_name (`str`, *optional*, defaults to None): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + + Example: + + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.unet.load_attn_procs( + "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" + ) + ``` + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + adapter_name = kwargs.pop("adapter_name", None) + _pipeline = kwargs.pop("_pipeline", None) + network_alphas = kwargs.pop("network_alphas", None) + allow_pickle = False + + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except IOError as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) + is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) + is_model_cpu_offload = False + is_sequential_cpu_offload = False + + if is_custom_diffusion: + attn_processors = self._process_custom_diffusion(state_dict=state_dict) + elif is_lora: + is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora( + state_dict=state_dict, + unet_identifier_key=self.unet_name, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + ) + else: + raise ValueError( + f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training." + ) + + # + + def _process_custom_diffusion(self, state_dict): + from ..models.attention_processor import CustomDiffusionAttnProcessor + + attn_processors = {} + custom_diffusion_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + if len(value) == 0: + custom_diffusion_grouped_dict[key] = {} + else: + if "to_out" in key: + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + else: + attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:]) + custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in custom_diffusion_grouped_dict.items(): + if len(value_dict) == 0: + attn_processors[key] = CustomDiffusionAttnProcessor( + train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None + ) + else: + cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1] + hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0] + train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False + attn_processors[key] = CustomDiffusionAttnProcessor( + train_kv=True, + train_q_out=train_q_out, + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + ) + attn_processors[key].load_state_dict(value_dict) + + return attn_processors + + def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline): + # This method does the following things: + # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy + # format. For legacy format no filtering is applied. + # 2. Converts the `state_dict` to the `peft` compatible format. + # 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the + # `LoraConfig` specs. + # 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it. + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + unet_keys = [k for k in keys if k.startswith(unet_identifier_key)] + unet_state_dict = { + k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys + } + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)] + network_alphas = { + k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict + + if len(state_dict_to_be_used) > 0: + if adapter_name in getattr(self, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name." + ) + + state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used) + + if network_alphas is not None: + # The alphas state dict have the same structure as Unet, thus we convert it to peft format using + # `convert_unet_state_dict_to_peft` method. + network_alphas = convert_unet_state_dict_to_peft(network_alphas) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(self) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + return is_model_cpu_offload, is_sequential_cpu_offload + + @classmethod + # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading + def _optionally_disable_offloading(cls, _pipeline): + """ + Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. + + Args: + _pipeline (`DiffusionPipeline`): + The pipeline to disable offloading for. + + Returns: + tuple: + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + """ + is_model_cpu_offload = False + is_sequential_cpu_offload = False + + if _pipeline is not None and _pipeline.hf_device_map is None: + for _, component in _pipeline.components.items(): + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if not is_model_cpu_offload: + is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) + if not is_sequential_cpu_offload: + is_sequential_cpu_offload = ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) + + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + + return (is_model_cpu_offload, is_sequential_cpu_offload) + + def save_attn_procs( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + **kwargs, + ): + r""" + Save attention processor layers to a directory so that it can be reloaded with the + [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save an attention processor to (will be created if it doesn't exist). + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or with `pickle`. + + Example: + + ```py + import torch + from diffusers import DiffusionPipeline + + pipeline = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + torch_dtype=torch.float16, + ).to("cuda") + pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin") + pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin") + ``` + """ + from ..models.attention_processor import ( + CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, + CustomDiffusionXFormersAttnProcessor, + ) + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + is_custom_diffusion = any( + isinstance( + x, + (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor), + ) + for (_, x) in self.attn_processors.items() + ) + if is_custom_diffusion: + state_dict = self._get_custom_diffusion_state_dict() + if save_function is None and safe_serialization: + # safetensors does not support saving dicts with non-tensor values + empty_state_dict = {k: v for k, v in state_dict.items() if not isinstance(v, torch.Tensor)} + if len(empty_state_dict) > 0: + logger.warning( + f"Safetensors does not support saving dicts with non-tensor values. " + f"The following keys will be ignored: {empty_state_dict.keys()}" + ) + state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)} + else: + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.") + + from peft.utils import get_peft_model_state_dict + + state_dict = get_peft_model_state_dict(self) + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + if weight_name is None: + if safe_serialization: + weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE + else: + weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME + + # Save the model + save_path = Path(save_directory, weight_name).as_posix() + save_function(state_dict, save_path) + logger.info(f"Model weights saved in {save_path}") + + def _get_custom_diffusion_state_dict(self): + from ..models.attention_processor import ( + CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, + CustomDiffusionXFormersAttnProcessor, + ) + + model_to_save = AttnProcsLayers( + { + y: x + for (y, x) in self.attn_processors.items() + if isinstance( + x, + ( + CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, + CustomDiffusionXFormersAttnProcessor, + ), + ) + } + ) + state_dict = model_to_save.state_dict() + for name, attn in self.attn_processors.items(): + if len(attn.state_dict()) == 0: + state_dict[name] = {} + + return state_dict + + def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False): + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + else: + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + updated_state_dict = {} + image_projection = None + init_context = init_empty_weights if low_cpu_mem_usage else nullcontext + + if "proj.weight" in state_dict: + # IP-Adapter + num_image_text_embeds = 4 + clip_embeddings_dim = state_dict["proj.weight"].shape[-1] + cross_attention_dim = state_dict["proj.weight"].shape[0] // 4 + + with init_context(): + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("proj", "image_embeds") + updated_state_dict[diffusers_name] = value + + elif "proj.3.weight" in state_dict: + # IP-Adapter Full + clip_embeddings_dim = state_dict["proj.0.weight"].shape[0] + cross_attention_dim = state_dict["proj.3.weight"].shape[0] + + with init_context(): + image_projection = IPAdapterFullImageProjection( + cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("proj.0", "ff.net.0.proj") + diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") + diffusers_name = diffusers_name.replace("proj.3", "norm") + updated_state_dict[diffusers_name] = value + + elif "perceiver_resampler.proj_in.weight" in state_dict: + # IP-Adapter Face ID Plus + id_embeddings_dim = state_dict["proj.0.weight"].shape[1] + embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0] + hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1] + output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0] + heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64 + + with init_context(): + image_projection = IPAdapterFaceIDPlusImageProjection( + embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + heads=heads, + id_embeddings_dim=id_embeddings_dim, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("perceiver_resampler.", "") + diffusers_name = diffusers_name.replace("0.to", "attn.to") + diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.") + diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.") + diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.") + diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.") + diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0") + diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1") + diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0") + diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1") + diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0") + diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1") + diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0") + diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1") + + if "norm1" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value + elif "norm2" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value + elif "to_kv" in diffusers_name: + v_chunk = value.chunk(2, dim=0) + updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] + updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_out" in diffusers_name: + updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value + elif "proj.0.weight" == diffusers_name: + updated_state_dict["proj.net.0.proj.weight"] = value + elif "proj.0.bias" == diffusers_name: + updated_state_dict["proj.net.0.proj.bias"] = value + elif "proj.2.weight" == diffusers_name: + updated_state_dict["proj.net.2.weight"] = value + elif "proj.2.bias" == diffusers_name: + updated_state_dict["proj.net.2.bias"] = value + else: + updated_state_dict[diffusers_name] = value + + elif "norm.weight" in state_dict: + # IP-Adapter Face ID + id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] + id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] + multiplier = id_embeddings_dim_out // id_embeddings_dim_in + norm_layer = "norm.weight" + cross_attention_dim = state_dict[norm_layer].shape[0] + num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim + + with init_context(): + image_projection = IPAdapterFaceIDImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=id_embeddings_dim_in, + mult=multiplier, + num_tokens=num_tokens, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("proj.0", "ff.net.0.proj") + diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") + updated_state_dict[diffusers_name] = value + + else: + # IP-Adapter Plus + num_image_text_embeds = state_dict["latents"].shape[1] + embed_dims = state_dict["proj_in.weight"].shape[1] + output_dims = state_dict["proj_out.weight"].shape[0] + hidden_dims = state_dict["latents"].shape[2] + attn_key_present = any("attn" in k for k in state_dict) + heads = ( + state_dict["layers.0.attn.to_q.weight"].shape[0] // 64 + if attn_key_present + else state_dict["layers.0.0.to_q.weight"].shape[0] // 64 + ) + + with init_context(): + image_projection = IPAdapterPlusImageProjection( + embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + heads=heads, + num_queries=num_image_text_embeds, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("0.to", "2.to") + + diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0") + diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1") + diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0") + diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1") + diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0") + diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1") + diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0") + diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1") + + if "to_kv" in diffusers_name: + parts = diffusers_name.split(".") + parts[2] = "attn" + diffusers_name = ".".join(parts) + v_chunk = value.chunk(2, dim=0) + updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] + updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_q" in diffusers_name: + parts = diffusers_name.split(".") + parts[2] = "attn" + diffusers_name = ".".join(parts) + updated_state_dict[diffusers_name] = value + elif "to_out" in diffusers_name: + parts = diffusers_name.split(".") + parts[2] = "attn" + diffusers_name = ".".join(parts) + updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value + else: + diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0") + diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2") + + diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0") + diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2") + + diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0") + diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2") + + diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0") + diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2") + updated_state_dict[diffusers_name] = value + + if not low_cpu_mem_usage: + image_projection.load_state_dict(updated_state_dict, strict=True) + else: + load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) + + return image_projection + + def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False): + from ..models.attention_processor import ( + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + ) + + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + else: + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + # set ip-adapter cross-attention processors & load state_dict + attn_procs = {} + key_id = 1 + init_context = init_empty_weights if low_cpu_mem_usage else nullcontext + for name in self.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = self.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.config.block_out_channels[block_id] + + if cross_attention_dim is None or "motion_modules" in name: + attn_processor_class = self.attn_processors[name].__class__ + attn_procs[name] = attn_processor_class() + + else: + attn_processor_class = ( + IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor + ) + num_image_text_embeds = [] + for state_dict in state_dicts: + if "proj.weight" in state_dict["image_proj"]: + # IP-Adapter + num_image_text_embeds += [4] + elif "proj.3.weight" in state_dict["image_proj"]: + # IP-Adapter Full Face + num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token + elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]: + # IP-Adapter Face ID Plus + num_image_text_embeds += [4] + elif "norm.weight" in state_dict["image_proj"]: + # IP-Adapter Face ID + num_image_text_embeds += [4] + else: + # IP-Adapter Plus + num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] + + with init_context(): + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, + ) + + value_dict = {} + for i, state_dict in enumerate(state_dicts): + value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) + value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) + + if not low_cpu_mem_usage: + attn_procs[name].load_state_dict(value_dict) + else: + device = next(iter(value_dict.values())).device + dtype = next(iter(value_dict.values())).dtype + load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) + + key_id += 2 + + return attn_procs + + def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): + if not isinstance(state_dicts, list): + state_dicts = [state_dicts] + + # Kolors Unet already has a `encoder_hid_proj` + if ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "text_proj" + and not hasattr(self, "text_encoder_hid_proj") + ): + self.text_encoder_hid_proj = self.encoder_hid_proj + + # Set encoder_hid_proj after loading ip_adapter weights, + # because `IPAdapterPlusImageProjection` also has `attn_processors`. + self.encoder_hid_proj = None + + attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + self.set_attn_processor(attn_procs) + + # convert IP-Adapter Image Projection layers to diffusers + image_projection_layers = [] + for state_dict in state_dicts: + image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers( + state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage + ) + image_projection_layers.append(image_projection_layer) + + self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) + self.config.encoder_hid_dim_type = "ip_image_proj" + + self.to(dtype=self.dtype, device=self.device) + + def _load_ip_adapter_loras(self, state_dicts): + lora_dicts = {} + for key_id, name in enumerate(self.attn_processors.keys()): + for i, state_dict in enumerate(state_dicts): + if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: + if i not in lora_dicts: + lora_dicts[i] = {} + lora_dicts[i].update( + { + f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_k_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_q_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_v_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]} + ) + lora_dicts[i].update( + {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]} + ) + lora_dicts[i].update( + {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]} + ) + lora_dicts[i].update( + { + f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.up.weight" + ] + } + ) + return lora_dicts diff --git a/diffusers/src/diffusers/loaders/unet_loader_utils.py b/diffusers/src/diffusers/loaders/unet_loader_utils.py new file mode 100755 index 0000000..8f202ed --- /dev/null +++ b/diffusers/src/diffusers/loaders/unet_loader_utils.py @@ -0,0 +1,163 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from typing import TYPE_CHECKING, Dict, List, Union + +from ..utils import logging + + +if TYPE_CHECKING: + # import here to avoid circular imports + from ..models import UNet2DConditionModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _translate_into_actual_layer_name(name): + """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')""" + if name == "mid": + return "mid_block.attentions.0" + + updown, block, attn = name.split(".") + + updown = updown.replace("down", "down_blocks").replace("up", "up_blocks") + block = block.replace("block_", "") + attn = "attentions." + attn + + return ".".join((updown, block, attn)) + + +def _maybe_expand_lora_scales( + unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0 +): + blocks_with_transformer = { + "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")], + "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")], + } + transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1} + + expanded_weight_scales = [ + _maybe_expand_lora_scales_for_one_adapter( + weight_for_adapter, + blocks_with_transformer, + transformer_per_block, + unet.state_dict(), + default_scale=default_scale, + ) + for weight_for_adapter in weight_scales + ] + + return expanded_weight_scales + + +def _maybe_expand_lora_scales_for_one_adapter( + scales: Union[float, Dict], + blocks_with_transformer: Dict[str, int], + transformer_per_block: Dict[str, int], + state_dict: None, + default_scale: float = 1.0, +): + """ + Expands the inputs into a more granular dictionary. See the example below for more details. + + Parameters: + scales (`Union[float, Dict]`): + Scales dict to expand. + blocks_with_transformer (`Dict[str, int]`): + Dict with keys 'up' and 'down', showing which blocks have transformer layers + transformer_per_block (`Dict[str, int]`): + Dict with keys 'up' and 'down', showing how many transformer layers each block has + + E.g. turns + ```python + scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}} + blocks_with_transformer = {"down": [1, 2], "up": [0, 1]} + transformer_per_block = {"down": 2, "up": 3} + ``` + into + ```python + { + "down.block_1.0": 2, + "down.block_1.1": 2, + "down.block_2.0": 2, + "down.block_2.1": 2, + "mid": 3, + "up.block_0.0": 4, + "up.block_0.1": 4, + "up.block_0.2": 4, + "up.block_1.0": 5, + "up.block_1.1": 6, + "up.block_1.2": 7, + } + ``` + """ + if sorted(blocks_with_transformer.keys()) != ["down", "up"]: + raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`") + + if sorted(transformer_per_block.keys()) != ["down", "up"]: + raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`") + + if not isinstance(scales, dict): + # don't expand if scales is a single number + return scales + + scales = copy.deepcopy(scales) + + if "mid" not in scales: + scales["mid"] = default_scale + elif isinstance(scales["mid"], list): + if len(scales["mid"]) == 1: + scales["mid"] = scales["mid"][0] + else: + raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.") + + for updown in ["up", "down"]: + if updown not in scales: + scales[updown] = default_scale + + # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}} + if not isinstance(scales[updown], dict): + scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]} + + # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}} + for i in blocks_with_transformer[updown]: + block = f"block_{i}" + # set not assigned blocks to default scale + if block not in scales[updown]: + scales[updown][block] = default_scale + if not isinstance(scales[updown][block], list): + scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])] + elif len(scales[updown][block]) == 1: + # a list specifying scale to each masked IP input + scales[updown][block] = scales[updown][block] * transformer_per_block[updown] + elif len(scales[updown][block]) != transformer_per_block[updown]: + raise ValueError( + f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}." + ) + + # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1} + for i in blocks_with_transformer[updown]: + block = f"block_{i}" + for tf_idx, value in enumerate(scales[updown][block]): + scales[f"{updown}.{block}.{tf_idx}"] = value + + del scales[updown] + + for layer in scales.keys(): + if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()): + raise ValueError( + f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions." + ) + + return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()} diff --git a/diffusers/src/diffusers/loaders/utils.py b/diffusers/src/diffusers/loaders/utils.py new file mode 100755 index 0000000..142d72b --- /dev/null +++ b/diffusers/src/diffusers/loaders/utils.py @@ -0,0 +1,59 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import torch + + +class AttnProcsLayers(torch.nn.Module): + def __init__(self, state_dict: Dict[str, torch.Tensor]): + super().__init__() + self.layers = torch.nn.ModuleList(state_dict.values()) + self.mapping = dict(enumerate(state_dict.keys())) + self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + + # .processor for unet, .self_attn for text encoder + self.split_keys = [".processor", ".self_attn"] + + # we add a hook to state_dict() and load_state_dict() so that the + # naming fits with `unet.attn_processors` + def map_to(module, state_dict, *args, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + num = int(key.split(".")[1]) # 0 is always "layers" + new_key = key.replace(f"layers.{num}", module.mapping[num]) + new_state_dict[new_key] = value + + return new_state_dict + + def remap_key(key, state_dict): + for k in self.split_keys: + if k in key: + return key.split(k)[0] + k + + raise ValueError( + f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}." + ) + + def map_from(module, state_dict, *args, **kwargs): + all_keys = list(state_dict.keys()) + for key in all_keys: + replace_key = remap_key(key, state_dict) + new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") + state_dict[new_key] = state_dict[key] + del state_dict[key] + + self._register_state_dict_hook(map_to) + self._register_load_state_dict_pre_hook(map_from, with_module=True) diff --git a/diffusers/src/diffusers/models/README.md b/diffusers/src/diffusers/models/README.md new file mode 100755 index 0000000..fb91f59 --- /dev/null +++ b/diffusers/src/diffusers/models/README.md @@ -0,0 +1,3 @@ +# Models + +For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview). \ No newline at end of file diff --git a/diffusers/src/diffusers/models/__init__.py b/diffusers/src/diffusers/models/__init__.py new file mode 100755 index 0000000..5e00dae --- /dev/null +++ b/diffusers/src/diffusers/models/__init__.py @@ -0,0 +1,141 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, + is_flax_available, + is_torch_available, +) + + +_import_structure = {} + +if is_torch_available(): + _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter", "FullAdapter"] + _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] + _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] + _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] + _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] + _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] + _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] + _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] + _import_structure["autoencoders.vq_model"] = ["VQModel"] + _import_structure["branch_cogvideox"] = ["CogvideoXBranchModel"] + _import_structure["controlnet"] = ["ControlNetModel"] + _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] + _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"] + _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] + _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"] + _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] + _import_structure["embeddings"] = ["ImageProjection"] + _import_structure["modeling_utils"] = ["ModelMixin"] + _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] + _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] + _import_structure["transformers.cogvideox_transformer_3d_inpainting"] = ["CogVideoXTransformer3DInpaintModel"] + _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] + _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] + _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] + _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"] + _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"] + _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] + _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] + _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] + _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] + _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] + _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] + _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] + _import_structure["unets.unet_1d"] = ["UNet1DModel"] + _import_structure["unets.unet_2d"] = ["UNet2DModel"] + _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] + _import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"] + _import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"] + _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"] + _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] + _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] + _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"] + _import_structure["unets.uvit_2d"] = ["UVit2DModel"] + +if is_flax_available(): + _import_structure["controlnet_flax"] = ["FlaxControlNetModel"] + _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["vae_flax"] = ["FlaxAutoencoderKL"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + if is_torch_available(): + from .adapter import MultiAdapter, T2IAdapter, FullAdapter + from .autoencoders import ( + AsymmetricAutoencoderKL, + AutoencoderKL, + AutoencoderKLCogVideoX, + AutoencoderKLTemporalDecoder, + AutoencoderOobleck, + AutoencoderTiny, + ConsistencyDecoderVAE, + VQModel, + ) + from .branch_cogvideox import CogvideoXBranchModel + from .controlnet import ControlNetModel + from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel + from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel + from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel + from .controlnet_sparsectrl import SparseControlNetModel + from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel + from .embeddings import ImageProjection + from .modeling_utils import ModelMixin + from .transformers import ( + AuraFlowTransformer2DModel, + CogVideoXTransformer3DModel, + CogVideoXTransformer3DInpaintModel, + DiTTransformer2DModel, + DualTransformer2DModel, + FluxTransformer2DModel, + HunyuanDiT2DModel, + LatteTransformer3DModel, + LuminaNextDiT2DModel, + PixArtTransformer2DModel, + PriorTransformer, + SD3Transformer2DModel, + StableAudioDiTModel, + T5FilmDecoder, + Transformer2DModel, + TransformerTemporalModel, + ) + from .unets import ( + I2VGenXLUNet, + Kandinsky3UNet, + MotionAdapter, + StableCascadeUNet, + UNet1DModel, + UNet2DConditionModel, + UNet2DModel, + UNet3DConditionModel, + UNetMotionModel, + UNetSpatioTemporalConditionModel, + UVit2DModel, + ) + + if is_flax_available(): + from .controlnet_flax import FlaxControlNetModel + from .unets import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/diffusers/src/diffusers/models/activations.py b/diffusers/src/diffusers/models/activations.py new file mode 100755 index 0000000..fb24a36 --- /dev/null +++ b/diffusers/src/diffusers/models/activations.py @@ -0,0 +1,165 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F +from torch import nn + +from ..utils import deprecate +from ..utils.import_utils import is_torch_npu_available + + +if is_torch_npu_available(): + import torch_npu + +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU(), +} + + +def get_activation(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class FP32SiLU(nn.Module): + r""" + SiLU activation function with input upcasted to torch.float32. + """ + + def __init__(self): + super().__init__() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return F.silu(inputs.float(), inplace=False).to(inputs.dtype) + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + hidden_states = self.proj(hidden_states) + if is_torch_npu_available(): + # using torch_npu.npu_geglu can run faster and save memory on NPU. + return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0] + else: + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class SwiGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU` + but uses SiLU / Swish instead of GeLU. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.activation = nn.SiLU() + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.activation(gate) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) diff --git a/diffusers/src/diffusers/models/adapter.py b/diffusers/src/diffusers/models/adapter.py new file mode 100755 index 0000000..fdf93c6 --- /dev/null +++ b/diffusers/src/diffusers/models/adapter.py @@ -0,0 +1,586 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Callable, List, Optional, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging +from .modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) + + +class MultiAdapter(ModelMixin): + r""" + MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to + user-assigned weighting. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + adapters (`List[T2IAdapter]`, *optional*, defaults to None): + A list of `T2IAdapter` model instances. + """ + + def __init__(self, adapters: List["T2IAdapter"]): + super(MultiAdapter, self).__init__() + + self.num_adapter = len(adapters) + self.adapters = nn.ModuleList(adapters) + + if len(adapters) == 0: + raise ValueError("Expecting at least one adapter") + + if len(adapters) == 1: + raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`") + + # The outputs from each adapter are added together with a weight. + # This means that the change in dimensions from downsampling must + # be the same for all adapters. Inductively, it also means the + # downscale_factor and total_downscale_factor must be the same for all + # adapters. + first_adapter_total_downscale_factor = adapters[0].total_downscale_factor + first_adapter_downscale_factor = adapters[0].downscale_factor + for idx in range(1, len(adapters)): + if ( + adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor + or adapters[idx].downscale_factor != first_adapter_downscale_factor + ): + raise ValueError( + f"Expecting all adapters to have the same downscaling behavior, but got:\n" + f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n" + f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n" + f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n" + f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}" + ) + + self.total_downscale_factor = first_adapter_total_downscale_factor + self.downscale_factor = first_adapter_downscale_factor + + def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]: + r""" + Args: + xs (`torch.Tensor`): + (batch, channel, height, width) input images for multiple adapter models concated along dimension 1, + `channel` should equal to `num_adapter` * "number of channel of image". + adapter_weights (`List[float]`, *optional*, defaults to None): + List of floats representing the weight which will be multiply to each adapter's output before adding + them together. + """ + if adapter_weights is None: + adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter) + else: + adapter_weights = torch.tensor(adapter_weights) + + accume_state = None + for x, w, adapter in zip(xs, adapter_weights, self.adapters): + features = adapter(x) + if accume_state is None: + accume_state = features + for i in range(len(accume_state)): + accume_state[i] = w * accume_state[i] + else: + for i in range(len(features)): + accume_state[i] += w * features[i] + return accume_state + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = None, + safe_serialization: bool = True, + variant: Optional[str] = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~models.adapter.MultiAdapter.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + """ + idx = 0 + model_path_to_save = save_directory + for adapter in self.adapters: + adapter.save_pretrained( + model_path_to_save, + is_main_process=is_main_process, + save_function=save_function, + safe_serialization=safe_serialization, + variant=variant, + ) + + idx += 1 + model_path_to_save = model_path_to_save + f"_{idx}" + + @classmethod + def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_path (`os.PathLike`): + A path to a *directory* containing model weights saved using + [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from + `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. + """ + idx = 0 + adapters = [] + + # load adapter and append to list until no adapter directory exists anymore + # first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained` + # second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ... + model_path_to_load = pretrained_model_path + while os.path.isdir(model_path_to_load): + adapter = T2IAdapter.from_pretrained(model_path_to_load, **kwargs) + adapters.append(adapter) + + idx += 1 + model_path_to_load = pretrained_model_path + f"_{idx}" + + logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.") + + if len(adapters) == 0: + raise ValueError( + f"No T2IAdapters found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}." + ) + + return cls(adapters) + + +class T2IAdapter(ModelMixin, ConfigMixin): + r""" + A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model + generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's + architecture follows the original implementation of + [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97) + and + [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (`int`, *optional*, defaults to 3): + Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale + image as *control image*. + channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will + also determine the number of downsample blocks in the Adapter. + num_res_blocks (`int`, *optional*, defaults to 2): + Number of ResNet blocks in each downsample block. + downscale_factor (`int`, *optional*, defaults to 8): + A factor that determines the total downscale factor of the Adapter. + adapter_type (`str`, *optional*, defaults to `full_adapter`): + The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + channels: List[int] = [320, 640, 1280, 1280], + num_res_blocks: int = 2, + downscale_factor: int = 8, + adapter_type: str = "full_adapter", + is_downsample: bool = True, + ): + super().__init__() + + if adapter_type == "full_adapter": + self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor, is_downsample) + elif adapter_type == "full_adapter_xl": + self.adapter = FullAdapterXL(in_channels, channels, num_res_blocks, downscale_factor) + elif adapter_type == "light_adapter": + self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor) + else: + raise ValueError( + f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or " + "'full_adapter_xl' or 'light_adapter'." + ) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + r""" + This function processes the input tensor `x` through the adapter model and returns a list of feature tensors, + each representing information extracted at a different scale from the input. The length of the list is + determined by the number of downsample blocks in the Adapter, as specified by the `channels` and + `num_res_blocks` parameters during initialization. + """ + return self.adapter(x) + + @property + def total_downscale_factor(self): + return self.adapter.total_downscale_factor + + @property + def downscale_factor(self): + """The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are + not evenly divisible by the downscale_factor then an exception will be raised. + """ + return self.adapter.unshuffle.downscale_factor + + +# full adapter + + +class FullAdapter(nn.Module): + r""" + See [`T2IAdapter`] for more information. + """ + + def __init__( + self, + in_channels: int = 3, + channels: List[int] = [320, 640, 1280, 1280], + num_res_blocks: int = 2, + downscale_factor: int = 8, + is_downsample: bool = True, + ): + super().__init__() + + in_channels = in_channels * downscale_factor**2 + + self.unshuffle = nn.PixelUnshuffle(downscale_factor) + self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1) + + self.body = nn.ModuleList( + [ + AdapterBlock(channels[0], channels[0], num_res_blocks), + *[ + AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=is_downsample) + for i in range(1, len(channels)) + ], + ] + ) + + self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + r""" + This method processes the input tensor `x` through the FullAdapter model and performs operations including + pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each + capturing information at a different stage of processing within the FullAdapter model. The number of feature + tensors in the list is determined by the number of downsample blocks specified during initialization. + """ + x = self.unshuffle(x) + x = self.conv_in(x) + + features = [] + + for block in self.body: + x = block(x) + features.append(x) + + return features + + +class FullAdapterXL(nn.Module): + r""" + See [`T2IAdapter`] for more information. + """ + + def __init__( + self, + in_channels: int = 3, + channels: List[int] = [320, 640, 1280, 1280], + num_res_blocks: int = 2, + downscale_factor: int = 16, + ): + super().__init__() + + in_channels = in_channels * downscale_factor**2 + + self.unshuffle = nn.PixelUnshuffle(downscale_factor) + self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1) + + self.body = [] + # blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32] + for i in range(len(channels)): + if i == 1: + self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks)) + elif i == 2: + self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)) + else: + self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks)) + + self.body = nn.ModuleList(self.body) + # XL has only one downsampling AdapterBlock. + self.total_downscale_factor = downscale_factor * 2 + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + r""" + This method takes the tensor x as input and processes it through FullAdapterXL model. It consists of operations + including unshuffling pixels, applying convolution layer and appending each block into list of feature tensors. + """ + x = self.unshuffle(x) + x = self.conv_in(x) + + features = [] + + for block in self.body: + x = block(x) + features.append(x) + + return features + + +class AdapterBlock(nn.Module): + r""" + An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and + `FullAdapterXL` models. + + Parameters: + in_channels (`int`): + Number of channels of AdapterBlock's input. + out_channels (`int`): + Number of channels of AdapterBlock's output. + num_res_blocks (`int`): + Number of ResNet blocks in the AdapterBlock. + down (`bool`, *optional*, defaults to `False`): + Whether to perform downsampling on AdapterBlock's input. + """ + + def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False): + super().__init__() + + self.downsample = None + if down: + self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) + + self.in_conv = None + if in_channels != out_channels: + self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + self.resnets = nn.Sequential( + *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + This method takes tensor x as input and performs operations downsampling and convolutional layers if the + self.downsample and self.in_conv properties of AdapterBlock model are specified. Then it applies a series of + residual blocks to the input tensor. + """ + if self.downsample is not None: + x = self.downsample(x) + + if self.in_conv is not None: + x = self.in_conv(x) + + x = self.resnets(x) + + return x + + +class AdapterResnetBlock(nn.Module): + r""" + An `AdapterResnetBlock` is a helper model that implements a ResNet-like block. + + Parameters: + channels (`int`): + Number of channels of AdapterResnetBlock's input and output. + """ + + def __init__(self, channels: int): + super().__init__() + self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.act = nn.ReLU() + self.block2 = nn.Conv2d(channels, channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional + layer on the input tensor. It returns addition with the input tensor. + """ + + h = self.act(self.block1(x)) + h = self.block2(h) + + return h + x + + +# light adapter + + +class LightAdapter(nn.Module): + r""" + See [`T2IAdapter`] for more information. + """ + + def __init__( + self, + in_channels: int = 3, + channels: List[int] = [320, 640, 1280], + num_res_blocks: int = 4, + downscale_factor: int = 8, + ): + super().__init__() + + in_channels = in_channels * downscale_factor**2 + + self.unshuffle = nn.PixelUnshuffle(downscale_factor) + + self.body = nn.ModuleList( + [ + LightAdapterBlock(in_channels, channels[0], num_res_blocks), + *[ + LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True) + for i in range(len(channels) - 1) + ], + LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True), + ] + ) + + self.total_downscale_factor = downscale_factor * (2 ** len(channels)) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + r""" + This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each + feature tensor corresponds to a different level of processing within the LightAdapter. + """ + x = self.unshuffle(x) + + features = [] + + for block in self.body: + x = block(x) + features.append(x) + + return features + + +class LightAdapterBlock(nn.Module): + r""" + A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the + `LightAdapter` model. + + Parameters: + in_channels (`int`): + Number of channels of LightAdapterBlock's input. + out_channels (`int`): + Number of channels of LightAdapterBlock's output. + num_res_blocks (`int`): + Number of LightAdapterResnetBlocks in the LightAdapterBlock. + down (`bool`, *optional*, defaults to `False`): + Whether to perform downsampling on LightAdapterBlock's input. + """ + + def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False): + super().__init__() + mid_channels = out_channels // 4 + + self.downsample = None + if down: + self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) + + self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1) + self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)]) + self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + This method takes tensor x as input and performs downsampling if required. Then it applies in convolution + layer, a sequence of residual blocks, and out convolutional layer. + """ + if self.downsample is not None: + x = self.downsample(x) + + x = self.in_conv(x) + x = self.resnets(x) + x = self.out_conv(x) + + return x + + +class LightAdapterResnetBlock(nn.Module): + """ + A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different + architecture than `AdapterResnetBlock`. + + Parameters: + channels (`int`): + Number of channels of LightAdapterResnetBlock's input and output. + """ + + def __init__(self, channels: int): + super().__init__() + self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.act = nn.ReLU() + self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and + another convolutional layer and adds it to input tensor. + """ + + h = self.act(self.block1(x)) + h = self.block2(h) + + return h + x diff --git a/diffusers/src/diffusers/models/attention.py b/diffusers/src/diffusers/models/attention.py new file mode 100755 index 0000000..84db0d0 --- /dev/null +++ b/diffusers/src/diffusers/models/attention.py @@ -0,0 +1,1202 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from ..utils import deprecate, logging +from ..utils.torch_utils import maybe_allow_in_graph +from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU +from .attention_processor import Attention, JointAttnProcessor2_0 +from .embeddings import SinusoidalPositionalEmbedding +from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm + + +logger = logging.get_logger(__name__) + + +def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + return ff_output + + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + + +@maybe_allow_in_graph +class JointTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False): + super().__init__() + + self.context_pre_only = context_pre_only + context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" + + self.norm1 = AdaLayerNormZero(dim) + + if context_norm_type == "ada_norm_continous": + self.norm1_context = AdaLayerNormContinuous( + dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" + ) + elif context_norm_type == "ada_norm_zero": + self.norm1_context = AdaLayerNormZero(dim) + else: + raise ValueError( + f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" + ) + if hasattr(F, "scaled_dot_product_attention"): + processor = JointAttnProcessor2_0() + else: + raise ValueError( + "The current PyTorch version does not support the `scaled_dot_product_attention` function." + ) + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=context_pre_only, + bias=True, + processor=processor, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + if not context_pre_only: + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + else: + self.norm2_context = None + self.ff_context = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor + ): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + if self.context_pre_only: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + else: + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + if self.context_pre_only: + encoder_hidden_states = None + else: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + context_ff_output = _chunked_feed_forward( + self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return encoder_hidden_states, hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + if norm_type == "ada_norm_single": # For Latte + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class LuminaFeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + hidden_size (`int`): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + intermediate_size (`int`): The intermediate dimension of the feedforward layer. + multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden + dimension. Defaults to None. + """ + + def __init__( + self, + dim: int, + inner_dim: int, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + inner_dim = int(2 * inner_dim / 3) + # custom hidden_size factor multiplier + if ffn_dim_multiplier is not None: + inner_dim = int(ffn_dim_multiplier * inner_dim) + inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) + + self.linear_1 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.linear_2 = nn.Linear( + inner_dim, + dim, + bias=False, + ) + self.linear_3 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.silu = FP32SiLU() + + def forward(self, x): + return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x)) + + +@maybe_allow_in_graph +class TemporalBasicTransformerBlock(nn.Module): + r""" + A basic Transformer block for video like data. + + Parameters: + dim (`int`): The number of channels in the input and output. + time_mix_inner_dim (`int`): The number of channels for temporal attention. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + """ + + def __init__( + self, + dim: int, + time_mix_inner_dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.is_res = dim == time_mix_inner_dim + + self.norm_in = nn.LayerNorm(dim) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.ff_in = FeedForward( + dim, + dim_out=time_mix_inner_dim, + activation_fn="geglu", + ) + + self.norm1 = nn.LayerNorm(time_mix_inner_dim) + self.attn1 = Attention( + query_dim=time_mix_inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + cross_attention_dim=None, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = nn.LayerNorm(time_mix_inner_dim) + self.attn2 = Attention( + query_dim=time_mix_inner_dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(time_mix_inner_dim) + self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = None + + def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): + # Sets chunk feed-forward + self._chunk_size = chunk_size + # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off + self._chunk_dim = 1 + + def forward( + self, + hidden_states: torch.Tensor, + num_frames: int, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + batch_frames, seq_length, channels = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) + + residual = hidden_states + hidden_states = self.norm_in(hidden_states) + + if self._chunk_size is not None: + hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) + else: + hidden_states = self.ff_in(hidden_states) + + if self.is_res: + hidden_states = hidden_states + residual + + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.is_res: + hidden_states = ff_output + hidden_states + else: + hidden_states = ff_output + + hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) + + return hidden_states + + +class SkipFFTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + kv_input_dim: int, + kv_input_dim_proj_use_bias: bool, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + attention_out_bias: bool = True, + ): + super().__init__() + if kv_input_dim != dim: + self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) + else: + self.kv_mapper = None + + self.norm1 = RMSNorm(dim, 1e-06) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + out_bias=attention_out_bias, + ) + + self.norm2 = RMSNorm(dim, 1e-06) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + out_bias=attention_out_bias, + ) + + def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + if self.kv_mapper is not None: + encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) + + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + return hidden_states + + +@maybe_allow_in_graph +class FreeNoiseTransformerBlock(nn.Module): + r""" + A FreeNoise Transformer block. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + cross_attention_dim (`int`, *optional*): + The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (`int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (`bool`, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, defaults to `False`): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, defaults to `False`): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, defaults to `False`): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + ff_inner_dim (`int`, *optional*): + Hidden dimension of feed-forward MLP. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in feed-forward MLP. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in attention output project layer. + context_length (`int`, defaults to `16`): + The maximum number of frames that the FreeNoise block processes at once. + context_stride (`int`, defaults to `4`): + The number of frames to be skipped before starting to process a new batch of `context_length` frames. + weighting_scheme (`str`, defaults to `"pyramid"`): + The weighting scheme to use for weighting averaging of processed latent frames. As described in the + Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting + used. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + context_length: int = 16, + context_stride: int = 4, + weighting_scheme: str = "pyramid", + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + self.set_free_noise_properties(context_length, context_stride, weighting_scheme) + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: + frame_indices = [] + for i in range(0, num_frames - self.context_length + 1, self.context_stride): + window_start = i + window_end = min(num_frames, i + self.context_length) + frame_indices.append((window_start, window_end)) + return frame_indices + + def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: + if weighting_scheme == "flat": + weights = [1.0] * num_frames + + elif weighting_scheme == "pyramid": + if num_frames % 2 == 0: + # num_frames = 4 => [1, 2, 2, 1] + mid = num_frames // 2 + weights = list(range(1, mid + 1)) + weights = weights + weights[::-1] + else: + # num_frames = 5 => [1, 2, 3, 2, 1] + mid = (num_frames + 1) // 2 + weights = list(range(1, mid)) + weights = weights + [mid] + weights[::-1] + + elif weighting_scheme == "delayed_reverse_sawtooth": + if num_frames % 2 == 0: + # num_frames = 4 => [0.01, 2, 2, 1] + mid = num_frames // 2 + weights = [0.01] * (mid - 1) + [mid] + weights = weights + list(range(mid, 0, -1)) + else: + # num_frames = 5 => [0.01, 0.01, 3, 2, 1] + mid = (num_frames + 1) // 2 + weights = [0.01] * mid + weights = weights + list(range(mid, 0, -1)) + else: + raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") + + return weights + + def set_free_noise_properties( + self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid" + ) -> None: + self.context_length = context_length + self.context_stride = context_stride + self.weighting_scheme = weighting_scheme + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None: + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + # hidden_states: [B x H x W, F, C] + device = hidden_states.device + dtype = hidden_states.dtype + + num_frames = hidden_states.size(1) + frame_indices = self._get_frame_indices(num_frames) + frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme) + frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1) + is_last_frame_batch_complete = frame_indices[-1][1] == num_frames + + # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length + # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges: + # [(0, 16), (4, 20), (8, 24), (10, 26)] + if not is_last_frame_batch_complete: + if num_frames < self.context_length: + raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}") + last_frame_batch_length = num_frames - frame_indices[-1][1] + frame_indices.append((num_frames - self.context_length, num_frames)) + + num_times_accumulated = torch.zeros((1, num_frames, 1), device=device) + accumulated_values = torch.zeros_like(hidden_states) + + for i, (frame_start, frame_end) in enumerate(frame_indices): + # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle + # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or + # essentially a non-multiple of `context_length`. + weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end]) + weights *= frame_weights + + hidden_states_chunk = hidden_states[:, frame_start:frame_end] + + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states_chunk) + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + hidden_states_chunk = attn_output + hidden_states_chunk + if hidden_states_chunk.ndim == 4: + hidden_states_chunk = hidden_states_chunk.squeeze(1) + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states_chunk) + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states_chunk = attn_output + hidden_states_chunk + + if i == len(frame_indices) - 1 and not is_last_frame_batch_complete: + accumulated_values[:, -last_frame_batch_length:] += ( + hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:] + ) + num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length] + else: + accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights + num_times_accumulated[:, frame_start:frame_end] += weights + + # TODO(aryan): Maybe this could be done in a better way. + # + # Previously, this was: + # hidden_states = torch.where( + # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + # ) + # + # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory + # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes + # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly + # looked into this deeply because other memory optimizations led to more pronounced reductions. + hidden_states = torch.cat( + [ + torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) + for accumulated_split, num_times_split in zip( + accumulated_values.split(self.context_length, dim=1), + num_times_accumulated.split(self.context_length, dim=1), + ) + ], + dim=1, + ).to(dtype) + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "swiglu": + act_fn = SwiGLU(dim, inner_dim, bias=bias) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states diff --git a/diffusers/src/diffusers/models/attention_flax.py b/diffusers/src/diffusers/models/attention_flax.py new file mode 100755 index 0000000..25ae5d0 --- /dev/null +++ b/diffusers/src/diffusers/models/attention_flax.py @@ -0,0 +1,494 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import math + +import flax.linen as nn +import jax +import jax.numpy as jnp + + +def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): + """Multi-head dot product attention with a limited number of queries.""" + num_kv, num_heads, k_features = key.shape[-3:] + v_features = value.shape[-1] + key_chunk_size = min(key_chunk_size, num_kv) + query = query / jnp.sqrt(k_features) + + @functools.partial(jax.checkpoint, prevent_cse=False) + def summarize_chunk(query, key, value): + attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision) + + max_score = jnp.max(attn_weights, axis=-1, keepdims=True) + max_score = jax.lax.stop_gradient(max_score) + exp_weights = jnp.exp(attn_weights - max_score) + + exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision) + max_score = jnp.einsum("...qhk->...qh", max_score) + + return (exp_values, exp_weights.sum(axis=-1), max_score) + + def chunk_scanner(chunk_idx): + # julienne key array + key_chunk = jax.lax.dynamic_slice( + operand=key, + start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d] + slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d] + ) + + # julienne value array + value_chunk = jax.lax.dynamic_slice( + operand=value, + start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d] + slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d] + ) + + return summarize_chunk(query, key_chunk, value_chunk) + + chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)) + + global_max = jnp.max(chunk_max, axis=0, keepdims=True) + max_diffs = jnp.exp(chunk_max - global_max) + + chunk_values *= jnp.expand_dims(max_diffs, axis=-1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(axis=0) + all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0) + + return all_values / all_weights + + +def jax_memory_efficient_attention( + query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096 +): + r""" + Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2 + https://github.com/AminRezaei0x443/memory-efficient-attention + + Args: + query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head) + key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head) + value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head) + precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`): + numerical precision for computation + query_chunk_size (`int`, *optional*, defaults to 1024): + chunk size to divide query array value must divide query_length equally without remainder + key_chunk_size (`int`, *optional*, defaults to 4096): + chunk size to divide key and value array value must divide key_value_length equally without remainder + + Returns: + (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head) + """ + num_q, num_heads, q_features = query.shape[-3:] + + def chunk_scanner(chunk_idx, _): + # julienne query array + query_chunk = jax.lax.dynamic_slice( + operand=query, + start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d] + slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d] + ) + + return ( + chunk_idx + query_chunk_size, # unused ignore it + _query_chunk_attention( + query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size + ), + ) + + _, res = jax.lax.scan( + f=chunk_scanner, + init=0, + xs=None, + length=math.ceil(num_q / query_chunk_size), # start counter # stop counter + ) + + return jnp.concatenate(res, axis=-3) # fuse the chunked result back + + +class FlaxAttention(nn.Module): + r""" + A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 + + Parameters: + query_dim (:obj:`int`): + Input hidden states dimension + heads (:obj:`int`, *optional*, defaults to 8): + Number of heads + dim_head (:obj:`int`, *optional*, defaults to 64): + Hidden states dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + split_head_dim (`bool`, *optional*, defaults to `False`): + Whether to split the head dimension into a new axis for the self-attention computation. In most cases, + enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + + """ + + query_dim: int + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim_head * self.heads + self.scale = self.dim_head**-0.5 + + # Weights were exported with old names {to_q, to_k, to_v, to_out} + self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") + self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") + self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") + + self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0") + self.dropout_layer = nn.Dropout(rate=self.dropout) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def __call__(self, hidden_states, context=None, deterministic=True): + context = hidden_states if context is None else context + + query_proj = self.query(hidden_states) + key_proj = self.key(context) + value_proj = self.value(context) + + if self.split_head_dim: + b = hidden_states.shape[0] + query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head)) + key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head)) + value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head)) + else: + query_states = self.reshape_heads_to_batch_dim(query_proj) + key_states = self.reshape_heads_to_batch_dim(key_proj) + value_states = self.reshape_heads_to_batch_dim(value_proj) + + if self.use_memory_efficient_attention: + query_states = query_states.transpose(1, 0, 2) + key_states = key_states.transpose(1, 0, 2) + value_states = value_states.transpose(1, 0, 2) + + # this if statement create a chunk size for each layer of the unet + # the chunk size is equal to the query_length dimension of the deepest layer of the unet + + flatten_latent_dim = query_states.shape[-3] + if flatten_latent_dim % 64 == 0: + query_chunk_size = int(flatten_latent_dim / 64) + elif flatten_latent_dim % 16 == 0: + query_chunk_size = int(flatten_latent_dim / 16) + elif flatten_latent_dim % 4 == 0: + query_chunk_size = int(flatten_latent_dim / 4) + else: + query_chunk_size = int(flatten_latent_dim) + + hidden_states = jax_memory_efficient_attention( + query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 + ) + + hidden_states = hidden_states.transpose(1, 0, 2) + else: + # compute attentions + if self.split_head_dim: + attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states) + else: + attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) + + attention_scores = attention_scores * self.scale + attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2) + + # attend to values + if self.split_head_dim: + hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states) + b = hidden_states.shape[0] + hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head)) + else: + hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + + hidden_states = self.proj_attn(hidden_states) + return self.dropout_layer(hidden_states, deterministic=deterministic) + + +class FlaxBasicTransformerBlock(nn.Module): + r""" + A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: + https://arxiv.org/abs/1706.03762 + + + Parameters: + dim (:obj:`int`): + Inner hidden states dimension + n_heads (:obj:`int`): + Number of heads + d_head (:obj:`int`): + Hidden states dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + only_cross_attention (`bool`, defaults to `False`): + Whether to only apply cross attention. + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + split_head_dim (`bool`, *optional*, defaults to `False`): + Whether to split the head dimension into a new axis for the self-attention computation. In most cases, + enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. + """ + + dim: int + n_heads: int + d_head: int + dropout: float = 0.0 + only_cross_attention: bool = False + dtype: jnp.dtype = jnp.float32 + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + + def setup(self): + # self attention (or cross_attention if only_cross_attention is True) + self.attn1 = FlaxAttention( + self.dim, + self.n_heads, + self.d_head, + self.dropout, + self.use_memory_efficient_attention, + self.split_head_dim, + dtype=self.dtype, + ) + # cross attention + self.attn2 = FlaxAttention( + self.dim, + self.n_heads, + self.d_head, + self.dropout, + self.use_memory_efficient_attention, + self.split_head_dim, + dtype=self.dtype, + ) + self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) + self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.dropout_layer = nn.Dropout(rate=self.dropout) + + def __call__(self, hidden_states, context, deterministic=True): + # self attention + residual = hidden_states + if self.only_cross_attention: + hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic) + else: + hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) + hidden_states = hidden_states + residual + + # cross attention + residual = hidden_states + hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic) + hidden_states = hidden_states + residual + + # feed forward + residual = hidden_states + hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic) + hidden_states = hidden_states + residual + + return self.dropout_layer(hidden_states, deterministic=deterministic) + + +class FlaxTransformer2DModel(nn.Module): + r""" + A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: + https://arxiv.org/pdf/1506.02025.pdf + + + Parameters: + in_channels (:obj:`int`): + Input number of channels + n_heads (:obj:`int`): + Number of heads + d_head (:obj:`int`): + Hidden states dimension inside each head + depth (:obj:`int`, *optional*, defaults to 1): + Number of transformers block + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + use_linear_projection (`bool`, defaults to `False`): tbd + only_cross_attention (`bool`, defaults to `False`): tbd + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + split_head_dim (`bool`, *optional*, defaults to `False`): + Whether to split the head dimension into a new axis for the self-attention computation. In most cases, + enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. + """ + + in_channels: int + n_heads: int + d_head: int + depth: int = 1 + dropout: float = 0.0 + use_linear_projection: bool = False + only_cross_attention: bool = False + dtype: jnp.dtype = jnp.float32 + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + + def setup(self): + self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) + + inner_dim = self.n_heads * self.d_head + if self.use_linear_projection: + self.proj_in = nn.Dense(inner_dim, dtype=self.dtype) + else: + self.proj_in = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + self.transformer_blocks = [ + FlaxBasicTransformerBlock( + inner_dim, + self.n_heads, + self.d_head, + dropout=self.dropout, + only_cross_attention=self.only_cross_attention, + dtype=self.dtype, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + ) + for _ in range(self.depth) + ] + + if self.use_linear_projection: + self.proj_out = nn.Dense(inner_dim, dtype=self.dtype) + else: + self.proj_out = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + def __call__(self, hidden_states, context, deterministic=True): + batch, height, width, channels = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) + if self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height * width, channels) + hidden_states = self.proj_in(hidden_states) + else: + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.reshape(batch, height * width, channels) + + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) + + if self.use_linear_projection: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, channels) + else: + hidden_states = hidden_states.reshape(batch, height, width, channels) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states + residual + return self.dropout_layer(hidden_states, deterministic=deterministic) + + +class FlaxFeedForward(nn.Module): + r""" + Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's + [`FeedForward`] class, with the following simplifications: + - The activation function is currently hardcoded to a gated linear unit from: + https://arxiv.org/abs/2002.05202 + - `dim_out` is equal to `dim`. + - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`]. + + Parameters: + dim (:obj:`int`): + Inner hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + dim: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # The second linear layer needs to be called + # net_2 for now to match the index of the Sequential layer + self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype) + self.net_2 = nn.Dense(self.dim, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.net_0(hidden_states, deterministic=deterministic) + hidden_states = self.net_2(hidden_states) + return hidden_states + + +class FlaxGEGLU(nn.Module): + r""" + Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from + https://arxiv.org/abs/2002.05202. + + Parameters: + dim (:obj:`int`): + Input hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + dim: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim * 4 + self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype) + self.dropout_layer = nn.Dropout(rate=self.dropout) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.proj(hidden_states) + hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) + return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic) diff --git a/diffusers/src/diffusers/models/attention_processor.py b/diffusers/src/diffusers/models/attention_processor.py new file mode 100755 index 0000000..12f97a9 --- /dev/null +++ b/diffusers/src/diffusers/models/attention_processor.py @@ -0,0 +1,4724 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..image_processor import IPAdapterMaskProcessor +from ..utils import deprecate, logging +from ..utils.import_utils import is_torch_npu_available, is_xformers_available +from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_torch_npu_available(): + import torch_npu + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + ): + super().__init__() + + # To prevent circular import. + from .normalization import FP32LayerNorm, RMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_k = nn.LayerNorm(dim_head, eps=eps) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "layer_norm_across_heads": + # Lumina applys qk norm across all heads + self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) + self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + self.norm_added_q = None + self.norm_added_k = None + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: + r""" + Set whether to use npu flash attention from `torch_npu` or not. + + """ + if use_npu_flash_attention: + processor = AttnProcessorNPU() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. + """ + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and is_custom_diffusion: + raise NotImplementedError( + f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @torch.no_grad() + def fuse_projections(self, fuse=True): + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not self.is_cross_attention: + # fetch weight matrices. + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + # create a new single projection layer and copy over the weights. + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + + # handle added projections for SD3 and others. + if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = fuse + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CustomDiffusionAttnProcessor(nn.Module): + r""" + Processor for implementing attention for the Custom Diffusion method. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + """ + + def __init__( + self, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + # `_custom_diffusion` id for easy serialization and loading. + if self.train_kv: + self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + if self.train_q_out: + self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) + self.to_out_custom_diffusion = nn.ModuleList([]) + self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) + self.to_out_custom_diffusion.append(nn.Dropout(dropout)) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype) + else: + query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype)) + + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + if self.train_kv: + key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) + value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) + key = key.to(attn.to_q.weight.dtype) + value = value.to(attn.to_q.weight.dtype) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if crossattn: + detach = torch.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if self.train_q_out: + # linear proj + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + # dropout + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class AttnAddedKVProcessor: + r""" + Processor for performing attention-related computations with extra learnable key and value matrices for the text + encoder. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class AttnAddedKVProcessor2_0: + r""" + Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra + learnable key and value matrices for the text encoder. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query, out_dim=4) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key, out_dim=4) + value = attn.head_to_batch_dim(value, out_dim=4) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class JointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # attention + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states + + +class PAGJointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # store the length of image patch sequences to create a mask that prevents interaction between patches + # similar to making the self-attention map an identity matrix + identity_block_size = hidden_states.shape[1] + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2) + + ################## original path ################## + batch_size = encoder_hidden_states_org.shape[0] + + # `sample` projections. + query_org = attn.to_q(hidden_states_org) + key_org = attn.to_k(hidden_states_org) + value_org = attn.to_v(hidden_states_org) + + # `context` projections. + encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org) + encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org) + encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org) + + # attention + query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1) + key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1) + value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1) + + inner_dim = key_org.shape[-1] + head_dim = inner_dim // attn.heads + query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states_org = F.scaled_dot_product_attention( + query_org, key_org, value_org, dropout_p=0.0, is_causal=False + ) + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query_org.dtype) + + # Split the attention outputs. + hidden_states_org, encoder_hidden_states_org = ( + hidden_states_org[:, : residual.shape[1]], + hidden_states_org[:, residual.shape[1] :], + ) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + if not attn.context_pre_only: + encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + ################## perturbed path ################## + + batch_size = encoder_hidden_states_ptb.shape[0] + + # `sample` projections. + query_ptb = attn.to_q(hidden_states_ptb) + key_ptb = attn.to_k(hidden_states_ptb) + value_ptb = attn.to_v(hidden_states_ptb) + + # `context` projections. + encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb) + encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb) + encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb) + + # attention + query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1) + key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1) + value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1) + + inner_dim = key_ptb.shape[-1] + head_dim = inner_dim // attn.heads + query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # create a full mask with all entries set to 0 + seq_len = query_ptb.size(2) + full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype) + + # set the attention value between image patches to -inf + full_mask[:identity_block_size, :identity_block_size] = float("-inf") + + # set the diagonal of the attention value between image patches to 0 + full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0) + + # expand the mask to match the attention weights shape + full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions + + hidden_states_ptb = F.scaled_dot_product_attention( + query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False + ) + hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype) + + # split the attention outputs. + hidden_states_ptb, encoder_hidden_states_ptb = ( + hidden_states_ptb[:, : residual.shape[1]], + hidden_states_ptb[:, residual.shape[1] :], + ) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + if not attn.context_pre_only: + encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + ################ concat ############### + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb]) + + return hidden_states, encoder_hidden_states + + +class PAGCFGJointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + identity_block_size = hidden_states.shape[ + 1 + ] # patch embeddings width * height (correspond to self-attention map width or height) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + ( + encoder_hidden_states_uncond, + encoder_hidden_states_org, + encoder_hidden_states_ptb, + ) = encoder_hidden_states.chunk(3) + encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org]) + + ################## original path ################## + batch_size = encoder_hidden_states_org.shape[0] + + # `sample` projections. + query_org = attn.to_q(hidden_states_org) + key_org = attn.to_k(hidden_states_org) + value_org = attn.to_v(hidden_states_org) + + # `context` projections. + encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org) + encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org) + encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org) + + # attention + query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1) + key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1) + value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1) + + inner_dim = key_org.shape[-1] + head_dim = inner_dim // attn.heads + query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states_org = F.scaled_dot_product_attention( + query_org, key_org, value_org, dropout_p=0.0, is_causal=False + ) + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query_org.dtype) + + # Split the attention outputs. + hidden_states_org, encoder_hidden_states_org = ( + hidden_states_org[:, : residual.shape[1]], + hidden_states_org[:, residual.shape[1] :], + ) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + if not attn.context_pre_only: + encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + ################## perturbed path ################## + + batch_size = encoder_hidden_states_ptb.shape[0] + + # `sample` projections. + query_ptb = attn.to_q(hidden_states_ptb) + key_ptb = attn.to_k(hidden_states_ptb) + value_ptb = attn.to_v(hidden_states_ptb) + + # `context` projections. + encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb) + encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb) + encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb) + + # attention + query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1) + key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1) + value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1) + + inner_dim = key_ptb.shape[-1] + head_dim = inner_dim // attn.heads + query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # create a full mask with all entries set to 0 + seq_len = query_ptb.size(2) + full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype) + + # set the attention value between image patches to -inf + full_mask[:identity_block_size, :identity_block_size] = float("-inf") + + # set the diagonal of the attention value between image patches to 0 + full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0) + + # expand the mask to match the attention weights shape + full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions + + hidden_states_ptb = F.scaled_dot_product_attention( + query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False + ) + hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype) + + # split the attention outputs. + hidden_states_ptb, encoder_hidden_states_ptb = ( + hidden_states_ptb[:, : residual.shape[1]], + hidden_states_ptb[:, residual.shape[1] :], + ) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + if not attn.context_pre_only: + encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + ################ concat ############### + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb]) + + return hidden_states, encoder_hidden_states + + +class FusedJointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + # `context` projections. + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + # attention + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states + + +class AuraFlowAttnProcessor2_0: + """Attention processor used typically in processing Aura Flow.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): + raise ImportError( + "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. " + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # Reshape. + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # Apply QK norm. + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Concatenate the projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) + + query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Attention. + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, encoder_hidden_states.shape[1] :], + hidden_states[:, : encoder_hidden_states.shape[1]], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FusedAuraFlowAttnProcessor2_0: + """Attention processor used typically in processing Aura Flow with fused projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): + raise ImportError( + "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. " + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + # Reshape. + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # Apply QK norm. + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Concatenate the projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) + + query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Attention. + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, encoder_hidden_states.shape[1] :], + hidden_states[:, : encoder_hidden_states.shape[1]], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + +class FluxAttnProcessor2_0_NPU: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU" + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if query.dtype in (torch.float16, torch.bfloat16): + hidden_states = torch_npu.npu_fusion_attention( + query, + key, + value, + attn.heads, + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + else: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + + +class FusedFluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FusedFluxAttnProcessor2_0_NPU: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU" + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if query.dtype in (torch.float16, torch.bfloat16): + hidden_states = torch_npu.npu_fusion_attention( + query, + key, + value, + attn.heads, + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + else: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class CogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + resample_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + prev_hidden_states: Optional[torch.Tensor] = None, + prev_clip_weight: Optional[float] = None, + prev_resample_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + sequence_length = hidden_states.shape[1] + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + if prev_hidden_states is not None and prev_clip_weight is not None and prev_clip_weight > 0.0: + prev_key = attn.to_k(prev_hidden_states) + prev_value = attn.to_v(prev_hidden_states) + + prev_key = prev_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + prev_value = prev_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_k is not None: + prev_key = attn.norm_k(prev_key) + # Apply RoPE if needed + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + if not attn.is_cross_attention: + prev_key[:, :, text_seq_length:] = apply_rotary_emb(prev_key[:, :, text_seq_length:], image_rotary_emb) + else: + prev_key = None + prev_value = None + + + if prev_key is not None and prev_value is not None: + hidden_states = F.scaled_dot_product_attention( + query, key, value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False + ) * (1 - prev_clip_weight) + + hidden_states = hidden_states + F.scaled_dot_product_attention( + query, prev_key, prev_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False + ) * prev_clip_weight + + else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + + +class CogVideoXAttnProcessor2_0_resample: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + resample_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + prev_hidden_states: Optional[torch.Tensor] = None, + prev_clip_weight: Optional[float] = None, + prev_resample_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + query_org = attn.to_q(hidden_states) + key_org = attn.to_k(hidden_states) + value_org = attn.to_v(hidden_states) + + if prev_hidden_states is not None and prev_clip_weight is not None and prev_clip_weight > 0.0: + # print(f"CogVideoXAttnProcessor2_0_resample: prev_resample_mask: {prev_resample_mask.shape}, prev_clip_weight: {prev_clip_weight}, prev_hidden_states: {prev_hidden_states.shape}") + prev_key = attn.to_k(prev_hidden_states) + prev_value = attn.to_v(prev_hidden_states) + key_mask = prev_key * prev_resample_mask.unsqueeze(-1) * prev_clip_weight + value_mask = prev_value * prev_resample_mask.unsqueeze(-1) * prev_clip_weight + else: + # print(f"CogVideoXAttnProcessor2_0_resample: resample_mask: {resample_mask.shape}, prev_clip_weight: {prev_clip_weight}, hidden_states: {hidden_states.shape}") + key_mask = key_org * resample_mask.unsqueeze(-1) + value_mask = value_org * resample_mask.unsqueeze(-1) + + inner_dim = key_org.shape[-1] + head_dim = inner_dim // attn.heads + + query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_mask = key_mask.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_mask = value_mask.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + + if attn.norm_q is not None: + query_org = attn.norm_q(query_org) + if attn.norm_k is not None: + key_org = attn.norm_k(key_org) + key_mask = attn.norm_k(key_mask) + + # Apply RoPE if needed + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query_org[:, :, text_seq_length:] = apply_rotary_emb(query_org[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key_org[:, :, text_seq_length:] = apply_rotary_emb(key_org[:, :, text_seq_length:], image_rotary_emb) + key_mask[:, :, text_seq_length:] = apply_rotary_emb(key_mask[:, :, text_seq_length:], image_rotary_emb) + + key_org = torch.cat([key_org, key_mask], dim=-2) + value_org = torch.cat([value_org, value_mask], dim=-2) + hidden_states = F.scaled_dot_product_attention( + query_org, key_org, value_org, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + +class CogVideoXAttnProcessor2_0_wo_text: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + batch_size, sequence_length, _ = hidden_states.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query[:, :, :] = apply_rotary_emb(query[:, :, :], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, :] = apply_rotary_emb(key[:, :, :], image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + +class FusedCogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +class XFormersAttnAddedKVProcessor: + r""" + Processor for implementing memory efficient attention using xFormers. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class XFormersAttnProcessor: + r""" + Processor for implementing memory efficient attention using xFormers. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, key_tokens, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessorNPU: + r""" + Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If + fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is + not significant. + + """ + + def __init__(self): + if not is_torch_npu_available(): + raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + if query.dtype in (torch.float16, torch.bfloat16): + hidden_states = torch_npu.npu_fusion_attention( + query, + key, + value, + attn.heads, + input_layout="BNSD", + pse=None, + atten_mask=attention_mask, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + else: + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class StableAudioAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def apply_partial_rotary_emb( + self, + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor], + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + rot_dim = freqs_cis[0].shape[-1] + x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:] + + x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2) + + out = torch.cat((x_rotated, x_unrotated), dim=-1) + return out + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = query.shape[-1] // attn.heads + kv_heads = key.shape[-1] // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + + if kv_heads != attn.heads: + # if GQA or MQA, repeat the key/value heads to reach the number of query heads. + heads_per_kv_head = attn.heads // kv_heads + key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) + value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if rotary_emb is not None: + query_dtype = query.dtype + key_dtype = key.dtype + query = query.to(torch.float32) + key = key.to(torch.float32) + + rot_dim = rotary_emb[0].shape[-1] + query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] + query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + + query = torch.cat((query_rotated, query_unrotated), dim=-1) + + if not attn.is_cross_attention: + key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] + key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + + key = torch.cat((key_rotated, key_unrotated), dim=-1) + + query = query.to(query_dtype) + key = key.to(key_dtype) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class HunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class FusedHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused + projection layers. This is used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on + query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + query = attn.to_q(hidden_states) + + kv = attn.to_kv(encoder_hidden_states) + split_size = kv.shape[-1] // 2 + key, value = torch.split(kv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGCFGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LuminaAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[torch.Tensor] = None, + key_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Apply Query-Key Norm if needed + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.view(batch_size, -1, attn.heads, head_dim) + + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply RoPE if needed + if query_rotary_emb is not None: + query = apply_rotary_emb(query, query_rotary_emb, use_real=False) + if key_rotary_emb is not None: + key = apply_rotary_emb(key, key_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Apply proportional attention if true + if key_rotary_emb is None: + softmax_scale = None + else: + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # perform Grouped-qurey Attention (GQA) + n_rep = attn.heads // kv_heads + if n_rep >= 1: + key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) + attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, scale=softmax_scale + ) + hidden_states = hidden_states.transpose(1, 2).to(dtype) + + return hidden_states + + +class FusedAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses + fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. + For cross-attention modules, key and value projection matrices are fused. + + + + This API is currently 🧪 experimental in nature and can change in future. + + + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + query = attn.to_q(hidden_states) + + kv = attn.to_kv(encoder_hidden_states) + split_size = kv.shape[-1] // 2 + key, value = torch.split(kv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CustomDiffusionXFormersAttnProcessor(nn.Module): + r""" + Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use + as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. + """ + + def __init__( + self, + train_kv: bool = True, + train_q_out: bool = False, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, + attention_op: Optional[Callable] = None, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.attention_op = attention_op + + # `_custom_diffusion` id for easy serialization and loading. + if self.train_kv: + self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + if self.train_q_out: + self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) + self.to_out_custom_diffusion = nn.ModuleList([]) + self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) + self.to_out_custom_diffusion.append(nn.Dropout(dropout)) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype) + else: + query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype)) + + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + if self.train_kv: + key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) + value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) + key = key.to(attn.to_q.weight.dtype) + value = value.to(attn.to_q.weight.dtype) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if crossattn: + detach = torch.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if self.train_q_out: + # linear proj + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + # dropout + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class CustomDiffusionAttnProcessor2_0(nn.Module): + r""" + Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0's memory-efficient scaled + dot-product attention. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + """ + + def __init__( + self, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + # `_custom_diffusion` id for easy serialization and loading. + if self.train_kv: + self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + if self.train_q_out: + self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) + self.to_out_custom_diffusion = nn.ModuleList([]) + self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) + self.to_out_custom_diffusion.append(nn.Dropout(dropout)) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states) + else: + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + if self.train_kv: + key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) + value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) + key = key.to(attn.to_q.weight.dtype) + value = value.to(attn.to_q.weight.dtype) + + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if crossattn: + detach = torch.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + + inner_dim = hidden_states.shape[-1] + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if self.train_q_out: + # linear proj + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + # dropout + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class SlicedAttnProcessor: + r""" + Processor for implementing sliced attention. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + + def __init__(self, slice_size: int): + self.slice_size = slice_size + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range((batch_size_attention - 1) // self.slice_size + 1): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class SlicedAttnAddedKVProcessor: + r""" + Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__( + self, + attn: "Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range((batch_size_attention - 1) // self.slice_size + 1): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class SpatialNorm(nn.Module): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + """ + + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +class IPAdapterAttnProcessor(nn.Module): + r""" + Attention processor for Multiple IP-Adapters. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): + The context length of the image features. + scale (`float` or List[`float`], defaults to 1.0): + the weight scale of image prompt. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + self.num_tokens = num_tokens + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + self.to_v_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.Tensor] = None, + ): + residual = hidden_states + + # separate ip_hidden_states from encoder_hidden_states + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, tuple): + encoder_hidden_states, ip_hidden_states = encoder_hidden_states + else: + deprecation_message = ( + "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." + ) + deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) + end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + [encoder_hidden_states[:, end_pos:, :]], + ) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, List): + # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): + raise ValueError( + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " + f"({len(ip_hidden_states)})" + ) + else: + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) + else: + ip_adapter_masks = [None] * len(self.scale) + + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ): + skip = False + if isinstance(scale, list): + if all(s == 0 for s in scale): + skip = True + elif scale == 0: + skip = True + if not skip: + if mask is not None: + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) + + hidden_states = hidden_states + scale * current_ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAdapterAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapter for PyTorch 2.0. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): + The context length of the image features. + scale (`float` or `List[float]`, defaults to 1.0): + the weight scale of image prompt. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + self.num_tokens = num_tokens + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + self.to_v_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.Tensor] = None, + ): + residual = hidden_states + + # separate ip_hidden_states from encoder_hidden_states + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, tuple): + encoder_hidden_states, ip_hidden_states = encoder_hidden_states + else: + deprecation_message = ( + "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." + ) + deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) + end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + [encoder_hidden_states[:, end_pos:, :]], + ) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, List): + # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): + raise ValueError( + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " + f"({len(ip_hidden_states)})" + ) + else: + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) + else: + ip_adapter_masks = [None] * len(self.scale) + + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ): + skip = False + if isinstance(scale, list): + if all(s == 0 for s in scale): + skip = True + elif scale == 0: + skip = True + if not skip: + if mask is not None: + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + _current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + scale * current_ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGIdentitySelfAttnProcessor2_0: + r""" + Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + PAG reference: https://arxiv.org/abs/2403.17377 + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGCFGIdentitySelfAttnProcessor2_0: + r""" + Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + PAG reference: https://arxiv.org/abs/2403.17377 + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + hidden_states_ptb = value + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAAttnProcessor: + def __init__(self): + pass + + +class LoRAAttnProcessor2_0: + def __init__(self): + pass + + +class LoRAXFormersAttnProcessor: + def __init__(self): + pass + + +class LoRAAttnAddedKVProcessor: + def __init__(self): + pass + + +class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead." + deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message) + super().__init__() + + +ADDED_KV_ATTENTION_PROCESSORS = ( + AttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + XFormersAttnAddedKVProcessor, +) + +CROSS_ATTENTION_PROCESSORS = ( + AttnProcessor, + AttnProcessor2_0, + XFormersAttnProcessor, + SlicedAttnProcessor, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, +) + +AttentionProcessor = Union[ + AttnProcessor, + AttnProcessor2_0, + FusedAttnProcessor2_0, + XFormersAttnProcessor, + SlicedAttnProcessor, + AttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + XFormersAttnAddedKVProcessor, + CustomDiffusionAttnProcessor, + CustomDiffusionXFormersAttnProcessor, + CustomDiffusionAttnProcessor2_0, + PAGCFGIdentitySelfAttnProcessor2_0, + PAGIdentitySelfAttnProcessor2_0, + PAGCFGHunyuanAttnProcessor2_0, + PAGHunyuanAttnProcessor2_0, + FluxAttnProcessor2_0, + FluxAttnProcessor2_0_NPU, + FusedFluxAttnProcessor2_0, + FusedFluxAttnProcessor2_0_NPU, +] diff --git a/diffusers/src/diffusers/models/autoencoders/__init__.py b/diffusers/src/diffusers/models/autoencoders/__init__.py new file mode 100755 index 0000000..ccf4552 --- /dev/null +++ b/diffusers/src/diffusers/models/autoencoders/__init__.py @@ -0,0 +1,8 @@ +from .autoencoder_asym_kl import AsymmetricAutoencoderKL +from .autoencoder_kl import AutoencoderKL +from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder +from .autoencoder_oobleck import AutoencoderOobleck +from .autoencoder_tiny import AutoencoderTiny +from .consistency_decoder_vae import ConsistencyDecoderVAE +from .vq_model import VQModel diff --git a/diffusers/src/diffusers/models/autoencoders/autoencoder_asym_kl.py b/diffusers/src/diffusers/models/autoencoders/autoencoder_asym_kl.py new file mode 100755 index 0000000..3f4d465 --- /dev/null +++ b/diffusers/src/diffusers/models/autoencoders/autoencoder_asym_kl.py @@ -0,0 +1,184 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder + + +class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): + r""" + Designing a Better Asymmetric VQGAN for StableDiffusion https://arxiv.org/abs/2306.04632 . A VAE model with KL loss + for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of down block output channels. + layers_per_down_block (`int`, *optional*, defaults to `1`): + Number layers for down block. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of up block output channels. + layers_per_up_block (`int`, *optional*, defaults to `1`): + Number layers for up block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + norm_num_groups (`int`, *optional*, defaults to `32`): + Number of groups to use for the first normalization layer in ResNet blocks. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + down_block_out_channels: Tuple[int, ...] = (64,), + layers_per_down_block: int = 1, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + up_block_out_channels: Tuple[int, ...] = (64,), + layers_per_up_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + ) -> None: + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=down_block_out_channels, + layers_per_block=layers_per_down_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = MaskConditionDecoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=up_block_out_channels, + layers_per_block=layers_per_up_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + self.use_slicing = False + self.use_tiling = False + + self.register_to_config(block_out_channels=up_block_out_channels) + self.register_to_config(force_upcast=False) + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]: + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, + z: torch.Tensor, + image: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[DecoderOutput, Tuple[torch.Tensor]]: + z = self.post_quant_conv(z) + dec = self.decoder(z, image, mask) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, + z: torch.Tensor, + generator: Optional[torch.Generator] = None, + image: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[DecoderOutput, Tuple[torch.Tensor]]: + decoded = self._decode(z, image, mask).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + mask: Optional[torch.Tensor] = None, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, Tuple[torch.Tensor]]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + mask (`torch.Tensor`, *optional*, defaults to `None`): Optional inpainting mask. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, generator, sample, mask).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py b/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py new file mode 100755 index 0000000..161770c --- /dev/null +++ b/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -0,0 +1,510 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils.accelerate_utils import apply_forward_hook +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + shift_factor: Optional[float] = None, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, + force_upcast: float = True, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + mid_block_add_attention: bool = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.tiled_encode(x, return_dict=return_dict) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + if self.quant_conv is not None: + moments = self.quant_conv(h) + else: + moments = h + + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) diff --git a/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py new file mode 100755 index 0000000..7f6c4c1 --- /dev/null +++ b/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -0,0 +1,1376 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..downsampling import CogVideoXDownsample3D +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..upsampling import CogVideoXUpsample3D +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CogVideoXSafeConv3d(nn.Conv3d): + r""" + A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + + # Set to 2GB, suitable for CuDNN + if memory_count > 2: + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) + + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super().forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super().forward(input) + + +class CogVideoXCausalConv3d(nn.Module): + r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. + + Args: + in_channels (`int`): Number of channels in the input tensor. + out_channels (`int`): Number of output channels produced by the convolution. + kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel. + stride (`int`, defaults to `1`): Stride of the convolution. + dilation (`int`, defaults to `1`): Dilation rate of the convolution. + pad_mode (`str`, defaults to `"constant"`): Padding mode. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: int = 1, + dilation: int = 1, + pad_mode: str = "constant", + ): + super().__init__() + + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + + self.temporal_dim = 2 + self.time_kernel_size = time_kernel_size + + stride = (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = CogVideoXSafeConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + self.conv_cache = None + + def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor: + kernel_size = self.time_kernel_size + if kernel_size > 1: + cached_inputs = ( + [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) + ) + inputs = torch.cat(cached_inputs + [inputs], dim=2) + return inputs + + def _clear_fake_context_parallel_cache(self): + del self.conv_cache + self.conv_cache = None + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + inputs = self.fake_context_parallel_forward(inputs) + + self._clear_fake_context_parallel_cache() + # Note: we could move these to the cpu for a lower maximum memory usage but its only a few + # hundred megabytes and so let's not do it for now + self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + inputs = F.pad(inputs, padding_2d, mode="constant", value=0) + + output = self.conv(inputs) + return output + + +class CogVideoXSpatialNorm3D(nn.Module): + r""" + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific + to 3D-video like data. + + CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + groups (`int`): + Number of groups to separate the channels into for group normalization. + """ + + def __init__( + self, + f_channels: int, + zq_channels: int, + groups: int = 32, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + if f.shape[2] > 1 and f.shape[2] % 2 == 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = F.interpolate(z_first, size=f_first_size) + z_rest = F.interpolate(z_rest, size=f_rest_size) + zq = torch.cat([z_first, z_rest], dim=2) + else: + zq = F.interpolate(zq, size=f.shape[-3:]) + + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +class CogVideoXResnetBlock3D(nn.Module): + r""" + A 3D ResNet block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + spatial_norm_dim (`int`, *optional*): + The dimension to use for spatial norm if it is to be used instead of group norm. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + conv_shortcut: bool = False, + spatial_norm_dim: Optional[int] = None, + pad_mode: str = "first", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.in_channels = in_channels + self.out_channels = out_channels + self.nonlinearity = get_activation(non_linearity) + self.use_conv_shortcut = conv_shortcut + + if spatial_norm_dim is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = CogVideoXSpatialNorm3D( + f_channels=in_channels, + zq_channels=spatial_norm_dim, + groups=groups, + ) + self.norm2 = CogVideoXSpatialNorm3D( + f_channels=out_channels, + zq_channels=spatial_norm_dim, + groups=groups, + ) + + self.conv1 = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + + if temb_channels > 0: + self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels) + + self.dropout = nn.Dropout(dropout) + self.conv2 = CogVideoXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + else: + self.conv_shortcut = CogVideoXSafeConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward( + self, + inputs: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = inputs + + if zq is not None: + hidden_states = self.norm1(hidden_states, zq) + else: + hidden_states = self.norm1(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if zq is not None: + hidden_states = self.norm2(hidden_states, zq) + else: + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +class CogVideoXDownBlock3D(nn.Module): + r""" + A downsampling block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + add_downsample (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 0, + compress_time: bool = False, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channel, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.downsamplers = None + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + CogVideoXDownsample3D( + out_channels, out_channels, padding=downsample_padding, compress_time=compress_time + ) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, zq + ) + else: + hidden_states = resnet(hidden_states, temb, zq) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class CogVideoXMidBlock3D(nn.Module): + r""" + A middle block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + dropout (`float`, defaults to `0.0`): + Dropout rate. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + spatial_norm_dim (`int`, *optional*): + The dimension to use for spatial norm if it is to be used instead of group norm. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatial_norm_dim: Optional[int] = None, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for _ in range(num_layers): + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + spatial_norm_dim=spatial_norm_dim, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, zq + ) + else: + hidden_states = resnet(hidden_states, temb, zq) + + return hidden_states + + +class CogVideoXUpBlock3D(nn.Module): + r""" + An upsampling block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + dropout (`float`, defaults to `0.0`): + Dropout rate. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + spatial_norm_dim (`int`, defaults to `16`): + The dimension to use for spatial norm if it is to be used instead of group norm. + add_upsample (`bool`, defaults to `True`): + Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension. + compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatial_norm_dim: int = 16, + add_upsample: bool = True, + upsample_padding: int = 1, + compress_time: bool = False, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channel, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + spatial_norm_dim=spatial_norm_dim, + pad_mode=pad_mode, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.upsamplers = None + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + CogVideoXUpsample3D( + out_channels, out_channels, padding=upsample_padding, compress_time=compress_time + ) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r"""Forward method of the `CogVideoXUpBlock3D` class.""" + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, zq + ) + else: + hidden_states = resnet(hidden_states, temb, zq) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CogVideoXEncoder3D(nn.Module): + r""" + The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + down_block_types: Tuple[str, ...] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + dropout: float = 0.0, + pad_mode: str = "first", + temporal_compression_ratio: float = 4, + ): + super().__init__() + + # log2 of temporal_compress_times + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + self.down_blocks = nn.ModuleList([]) + + # down blocks + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if down_block_type == "CogVideoXDownBlock3D": + down_block = CogVideoXDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_downsample=not is_final_block, + compress_time=compress_time, + ) + else: + raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=block_out_channels[-1], + temb_channels=0, + dropout=dropout, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + pad_mode=pad_mode, + ) + + self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + r"""The forward method of the `CogVideoXEncoder3D` class.""" + hidden_states = self.conv_in(sample) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # 1. Down + for down_block in self.down_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), hidden_states, temb, None + ) + + # 2. Mid + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, temb, None + ) + else: + # 1. Down + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states, temb, None) + + # 2. Mid + hidden_states = self.mid_block(hidden_states, temb, None) + + # 3. Post-process + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class CogVideoXDecoder3D(nn.Module): + r""" + The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + dropout: float = 0.0, + pad_mode: str = "first", + temporal_compression_ratio: float = 4, + ): + super().__init__() + + reversed_block_out_channels = list(reversed(block_out_channels)) + + self.conv_in = CogVideoXCausalConv3d( + in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode + ) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=reversed_block_out_channels[0], + temb_channels=0, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, + pad_mode=pad_mode, + ) + + # up blocks + self.up_blocks = nn.ModuleList([]) + + output_channel = reversed_block_out_channels[0] + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if up_block_type == "CogVideoXUpBlock3D": + up_block = CogVideoXUpBlock3D( + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block + 1, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final_block, + compress_time=compress_time, + pad_mode=pad_mode, + ) + prev_output_channel = output_channel + else: + raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`") + + self.up_blocks.append(up_block) + + self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + r"""The forward method of the `CogVideoXDecoder3D` class.""" + hidden_states = self.conv_in(sample) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # 1. Mid + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, temb, sample + ) + + # 2. Up + for up_block in self.up_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), hidden_states, temb, sample + ) + else: + # 1. Mid + hidden_states = self.mid_block(hidden_states, temb, sample) + + # 2. Up + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb, sample) + + # 3. Post-process + hidden_states = self.norm_out(hidden_states, sample) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [CogVideoX](https://github.com/THUDM/CogVideo). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to `1.15258426`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["CogVideoXResnetBlock3D"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types: Tuple[str] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels: Tuple[int] = (128, 256, 256, 512), + latent_channels: int = 16, + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + temporal_compression_ratio: float = 4, + sample_height: int = 480, + sample_width: int = 720, + scaling_factor: float = 1.15258426, + shift_factor: Optional[float] = None, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, + force_upcast: float = True, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, + ): + super().__init__() + + self.encoder = CogVideoXEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.decoder = CogVideoXDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None + self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None + + self.use_slicing = False + self.use_tiling = False + + # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not + # recommended because the temporal parts of the VAE, here, are tricky to understand. + # If you decode X latent frames together, the number of output frames is: + # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames + # + # Example with num_latent_frames_batch_size = 2: + # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together + # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + # => 6 * 8 = 48 frames + # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together + # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + + # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + # => 1 * 9 + 5 * 8 = 49 frames + # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that + # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different + # number of temporal frames. + self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 + + # We make the minimum height and width of sample for tiling half that of the generally supported + self.tile_sample_min_height = sample_height // 2 + self.tile_sample_min_width = sample_width // 2 + self.tile_latent_min_height = int( + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + + # These are experimental overlap factors that were chosen based on experimentation and seem to work best for + # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX + # and so the tiling implementation has only been tested on those specific resolutions. + self.tile_overlap_factor_height = 1 / 6 + self.tile_overlap_factor_width = 1 / 5 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): + module.gradient_checkpointing = value + + def _clear_fake_context_parallel_cache(self): + for name, module in self.named_modules(): + if isinstance(module, CogVideoXCausalConv3d): + logger.debug(f"Clearing fake Context Parallel cache for layer: {name}") + module._clear_fake_context_parallel_cache() + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_overlap_factor_height: Optional[float] = None, + tile_overlap_factor_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_overlap_factor_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + tile_overlap_factor_width (`int`, *optional*): + The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there + are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = int( + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height + self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + frame_batch_size = self.num_sample_frames_batch_size + # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. + num_batches = num_frames // frame_batch_size if num_frames > 1 else 1 + enc = [] + for i in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + x_intermediate = x[:, :, start_frame:end_frame] + x_intermediate = self.encoder(x_intermediate) + if self.quant_conv is not None: + x_intermediate = self.quant_conv(x_intermediate) + enc.append(x_intermediate) + + self._clear_fake_context_parallel_cache() + enc = torch.cat(enc, dim=2) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + + if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + frame_batch_size = self.num_latent_frames_batch_size + num_batches = num_frames // frame_batch_size + dec = [] + for i in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + z_intermediate = z[:, :, start_frame:end_frame] + if self.post_quant_conv is not None: + z_intermediate = self.post_quant_conv(z_intermediate) + z_intermediate = self.decoder(z_intermediate) + dec.append(z_intermediate) + + self._clear_fake_context_parallel_cache() + dec = torch.cat(dec, dim=2) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if z.shape[2] == 1: # Check if there's only 1 frame in the temporal dimension + z = torch.cat([z, z], dim=2) # Duplicate the frame + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + # For a rough memory estimate, take a look at the `tiled_decode` method. + batch_size, num_channels, num_frames, height, width = x.shape + + overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_latent_min_height - blend_extent_height + row_limit_width = self.tile_latent_min_width - blend_extent_width + frame_batch_size = self.num_sample_frames_batch_size + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. + num_batches = num_frames // frame_batch_size if num_frames > 1 else 1 + time = [] + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = x[ + :, + :, + start_frame:end_frame, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile) + if self.quant_conv is not None: + tile = self.quant_conv(tile) + time.append(tile) + self._clear_fake_context_parallel_cache() + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3) + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + # Rough memory assessment: + # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. + # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. + # - Assume fp16 (2 bytes per value). + # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB + # + # Memory assessment when using tiling: + # - Assume everything as above but now HxW is 240x360 by tiling in half + # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + + batch_size, num_channels, num_frames, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + num_batches = num_frames // frame_batch_size + time = [] + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile = self.decoder(tile) + time.append(tile) + self._clear_fake_context_parallel_cache() + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec,) + return dec diff --git a/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py new file mode 100755 index 0000000..5544964 --- /dev/null +++ b/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -0,0 +1,401 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version +from ...utils.accelerate_utils import apply_forward_hook +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder +from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder + + +class TemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int = 4, + out_channels: int = 3, + block_out_channels: Tuple[int] = (128, 256, 512, 512), + layers_per_block: int = 2, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) + self.mid_block = MidBlockTemporalDecoder( + num_layers=self.layers_per_block, + in_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + attention_head_dim=block_out_channels[-1], + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + up_block = UpBlockTemporalDecoder( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6) + + self.conv_act = nn.SiLU() + self.conv_out = torch.nn.Conv2d( + in_channels=block_out_channels[0], + out_channels=out_channels, + kernel_size=3, + padding=1, + ) + + conv_out_kernel_size = (3, 1, 1) + padding = [int(k // 2) for k in conv_out_kernel_size] + self.time_conv_out = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=conv_out_kernel_size, + padding=padding, + ) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + image_only_indicator: torch.Tensor, + num_frames: int = 1, + ) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + image_only_indicator, + use_reentrant=False, + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + image_only_indicator, + use_reentrant=False, + ) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + image_only_indicator, + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + image_only_indicator, + ) + else: + # middle + sample = self.mid_block(sample, image_only_indicator=image_only_indicator) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, image_only_indicator=image_only_indicator) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + batch_frames, channels, height, width = sample.shape + batch_size = batch_frames // num_frames + sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + sample = self.time_conv_out(sample) + + sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width) + + return sample + + +class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + latent_channels: int = 4, + sample_size: int = 32, + scaling_factor: float = 0.18215, + force_upcast: float = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = TemporalDecoder( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, TemporalDecoder)): + module.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a plain + tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + @apply_forward_hook + def decode( + self, + z: torch.Tensor, + num_frames: int, + return_dict: bool = True, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + batch_size = z.shape[0] // num_frames + image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device) + decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + num_frames: int = 1, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + + dec = self.decode(z, num_frames=num_frames).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/diffusers/src/diffusers/models/autoencoders/autoencoder_oobleck.py b/diffusers/src/diffusers/models/autoencoders/autoencoder_oobleck.py new file mode 100755 index 0000000..e8e372a --- /dev/null +++ b/diffusers/src/diffusers/models/autoencoders/autoencoder_oobleck.py @@ -0,0 +1,464 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.accelerate_utils import apply_forward_hook +from ...utils.torch_utils import randn_tensor +from ..modeling_utils import ModelMixin + + +class Snake1d(nn.Module): + """ + A 1-dimensional Snake activation function module. + """ + + def __init__(self, hidden_dim, logscale=True): + super().__init__() + self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + + self.alpha.requires_grad = True + self.beta.requires_grad = True + self.logscale = logscale + + def forward(self, hidden_states): + shape = hidden_states.shape + + alpha = self.alpha if not self.logscale else torch.exp(self.alpha) + beta = self.beta if not self.logscale else torch.exp(self.beta) + + hidden_states = hidden_states.reshape(shape[0], shape[1], -1) + hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2) + hidden_states = hidden_states.reshape(shape) + return hidden_states + + +class OobleckResidualUnit(nn.Module): + """ + A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations. + """ + + def __init__(self, dimension: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + + self.snake1 = Snake1d(dimension) + self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)) + self.snake2 = Snake1d(dimension) + self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1)) + + def forward(self, hidden_state): + """ + Forward pass through the residual unit. + + Args: + hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`): + Input tensor . + + Returns: + output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`) + Input tensor after passing through the residual unit. + """ + output_tensor = hidden_state + output_tensor = self.conv1(self.snake1(output_tensor)) + output_tensor = self.conv2(self.snake2(output_tensor)) + + padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2 + if padding > 0: + hidden_state = hidden_state[..., padding:-padding] + output_tensor = hidden_state + output_tensor + return output_tensor + + +class OobleckEncoderBlock(nn.Module): + """Encoder block used in Oobleck encoder.""" + + def __init__(self, input_dim, output_dim, stride: int = 1): + super().__init__() + + self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1) + self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3) + self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9) + self.snake1 = Snake1d(input_dim) + self.conv1 = weight_norm( + nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)) + ) + + def forward(self, hidden_state): + hidden_state = self.res_unit1(hidden_state) + hidden_state = self.res_unit2(hidden_state) + hidden_state = self.snake1(self.res_unit3(hidden_state)) + hidden_state = self.conv1(hidden_state) + + return hidden_state + + +class OobleckDecoderBlock(nn.Module): + """Decoder block used in Oobleck decoder.""" + + def __init__(self, input_dim, output_dim, stride: int = 1): + super().__init__() + + self.snake1 = Snake1d(input_dim) + self.conv_t1 = weight_norm( + nn.ConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ) + ) + self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1) + self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3) + self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9) + + def forward(self, hidden_state): + hidden_state = self.snake1(hidden_state) + hidden_state = self.conv_t1(hidden_state) + hidden_state = self.res_unit1(hidden_state) + hidden_state = self.res_unit2(hidden_state) + hidden_state = self.res_unit3(hidden_state) + + return hidden_state + + +class OobleckDiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.scale = parameters.chunk(2, dim=1) + self.std = nn.functional.softplus(self.scale) + 1e-4 + self.var = self.std * self.std + self.logvar = torch.log(self.var) + self.deterministic = deterministic + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean() + else: + normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var + var_ratio = self.var / other.var + logvar_diff = self.logvar - other.logvar + + kl = normalized_diff + var_ratio + logvar_diff - 1 + + kl = kl.sum(1).mean() + return kl + + def mode(self) -> torch.Tensor: + return self.mean + + +@dataclass +class AutoencoderOobleckOutput(BaseOutput): + """ + Output of AutoencoderOobleck encoding method. + + Args: + latent_dist (`OobleckDiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and standard deviation of + `OobleckDiagonalGaussianDistribution`. `OobleckDiagonalGaussianDistribution` allows for sampling latents + from the distribution. + """ + + latent_dist: "OobleckDiagonalGaussianDistribution" # noqa: F821 + + +@dataclass +class OobleckDecoderOutput(BaseOutput): + r""" + Output of decoding method. + + Args: + sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.Tensor + + +class OobleckEncoder(nn.Module): + """Oobleck Encoder""" + + def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples): + super().__init__() + + strides = downsampling_ratios + channel_multiples = [1] + channel_multiples + + # Create first convolution + self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3)) + + self.block = [] + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride_index, stride in enumerate(strides): + self.block += [ + OobleckEncoderBlock( + input_dim=encoder_hidden_size * channel_multiples[stride_index], + output_dim=encoder_hidden_size * channel_multiples[stride_index + 1], + stride=stride, + ) + ] + + self.block = nn.ModuleList(self.block) + d_model = encoder_hidden_size * channel_multiples[-1] + self.snake1 = Snake1d(d_model) + self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1)) + + def forward(self, hidden_state): + hidden_state = self.conv1(hidden_state) + + for module in self.block: + hidden_state = module(hidden_state) + + hidden_state = self.snake1(hidden_state) + hidden_state = self.conv2(hidden_state) + + return hidden_state + + +class OobleckDecoder(nn.Module): + """Oobleck Decoder""" + + def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples): + super().__init__() + + strides = upsampling_ratios + channel_multiples = [1] + channel_multiples + + # Add first conv layer + self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3)) + + # Add upsampling + MRF blocks + block = [] + for stride_index, stride in enumerate(strides): + block += [ + OobleckDecoderBlock( + input_dim=channels * channel_multiples[len(strides) - stride_index], + output_dim=channels * channel_multiples[len(strides) - stride_index - 1], + stride=stride, + ) + ] + + self.block = nn.ModuleList(block) + output_dim = channels + self.snake1 = Snake1d(output_dim) + self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False)) + + def forward(self, hidden_state): + hidden_state = self.conv1(hidden_state) + + for layer in self.block: + hidden_state = layer(hidden_state) + + hidden_state = self.snake1(hidden_state) + hidden_state = self.conv2(hidden_state) + + return hidden_state + + +class AutoencoderOobleck(ModelMixin, ConfigMixin): + r""" + An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First + introduced in Stable Audio. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + encoder_hidden_size (`int`, *optional*, defaults to 128): + Intermediate representation dimension for the encoder. + downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`): + Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder. + channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`): + Multiples used to determine the hidden sizes of the hidden layers. + decoder_channels (`int`, *optional*, defaults to 128): + Intermediate representation dimension for the decoder. + decoder_input_channels (`int`, *optional*, defaults to 64): + Input dimension for the decoder. Corresponds to the latent dimension. + audio_channels (`int`, *optional*, defaults to 2): + Number of channels in the audio data. Either 1 for mono or 2 for stereo. + sampling_rate (`int`, *optional*, defaults to 44100): + The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + encoder_hidden_size=128, + downsampling_ratios=[2, 4, 4, 8, 8], + channel_multiples=[1, 2, 4, 8, 16], + decoder_channels=128, + decoder_input_channels=64, + audio_channels=2, + sampling_rate=44100, + ): + super().__init__() + + self.encoder_hidden_size = encoder_hidden_size + self.downsampling_ratios = downsampling_ratios + self.decoder_channels = decoder_channels + self.upsampling_ratios = downsampling_ratios[::-1] + self.hop_length = int(np.prod(downsampling_ratios)) + self.sampling_rate = sampling_rate + + self.encoder = OobleckEncoder( + encoder_hidden_size=encoder_hidden_size, + audio_channels=audio_channels, + downsampling_ratios=downsampling_ratios, + channel_multiples=channel_multiples, + ) + + self.decoder = OobleckDecoder( + channels=decoder_channels, + input_channels=decoder_input_channels, + audio_channels=audio_channels, + upsampling_ratios=self.upsampling_ratios, + channel_multiples=channel_multiples, + ) + + self.use_slicing = False + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + posterior = OobleckDiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderOobleckOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]: + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return OobleckDecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[OobleckDecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.OobleckDecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.OobleckDecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.OobleckDecoderOutput`] is returned, otherwise a plain `tuple` + is returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return OobleckDecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[OobleckDecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`OobleckDecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return OobleckDecoderOutput(sample=dec) diff --git a/diffusers/src/diffusers/models/autoencoders/autoencoder_tiny.py b/diffusers/src/diffusers/models/autoencoders/autoencoder_tiny.py new file mode 100755 index 0000000..6e50347 --- /dev/null +++ b/diffusers/src/diffusers/models/autoencoders/autoencoder_tiny.py @@ -0,0 +1,348 @@ +# Copyright 2024 Ollin Boer Bohan and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DecoderTiny, EncoderTiny + + +@dataclass +class AutoencoderTinyOutput(BaseOutput): + """ + Output of AutoencoderTiny encoding method. + + Args: + latents (`torch.Tensor`): Encoded outputs of the `Encoder`. + + """ + + latents: torch.Tensor + + +class AutoencoderTiny(ModelMixin, ConfigMixin): + r""" + A tiny distilled VAE model for encoding images into latents and decoding latent representations into images. + + [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for + all models (such as downloading or saving). + + Parameters: + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. + encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`): + Tuple of integers representing the number of output channels for each encoder block. The length of the + tuple should be equal to the number of encoder blocks. + decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`): + Tuple of integers representing the number of output channels for each decoder block. The length of the + tuple should be equal to the number of decoder blocks. + act_fn (`str`, *optional*, defaults to `"relu"`): + Activation function to be used throughout the model. + latent_channels (`int`, *optional*, defaults to 4): + Number of channels in the latent representation. The latent space acts as a compressed representation of + the input image. + upsampling_scaling_factor (`int`, *optional*, defaults to 2): + Scaling factor for upsampling in the decoder. It determines the size of the output image during the + upsampling process. + num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`): + Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The + length of the tuple should be equal to the number of stages in the encoder. Each stage has a different + number of encoder blocks. + num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`): + Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The + length of the tuple should be equal to the number of stages in the decoder. Each stage has a different + number of decoder blocks. + latent_magnitude (`float`, *optional*, defaults to 3.0): + Magnitude of the latent representation. This parameter scales the latent representation values to control + the extent of information preservation. + latent_shift (float, *optional*, defaults to 0.5): + Shift applied to the latent representation. This parameter controls the center of the latent space. + scaling_factor (`float`, *optional*, defaults to 1.0): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder, + however, no such scaling factor was used, hence the value of 1.0 as the default. + force_upcast (`bool`, *optional*, default to `False`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision, in which case + `force_upcast` can be set to `False` (see this fp16-friendly + [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), + decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), + act_fn: str = "relu", + upsample_fn: str = "nearest", + latent_channels: int = 4, + upsampling_scaling_factor: int = 2, + num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3), + num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1), + latent_magnitude: int = 3, + latent_shift: float = 0.5, + force_upcast: bool = False, + scaling_factor: float = 1.0, + shift_factor: float = 0.0, + ): + super().__init__() + + if len(encoder_block_out_channels) != len(num_encoder_blocks): + raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.") + if len(decoder_block_out_channels) != len(num_decoder_blocks): + raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.") + + self.encoder = EncoderTiny( + in_channels=in_channels, + out_channels=latent_channels, + num_blocks=num_encoder_blocks, + block_out_channels=encoder_block_out_channels, + act_fn=act_fn, + ) + + self.decoder = DecoderTiny( + in_channels=latent_channels, + out_channels=out_channels, + num_blocks=num_decoder_blocks, + block_out_channels=decoder_block_out_channels, + upsampling_scaling_factor=upsampling_scaling_factor, + act_fn=act_fn, + upsample_fn=upsample_fn, + ) + + self.latent_magnitude = latent_magnitude + self.latent_shift = latent_shift + self.scaling_factor = scaling_factor + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.spatial_scale_factor = 2**out_channels + self.tile_overlap_factor = 0.125 + self.tile_sample_min_size = 512 + self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor + + self.register_to_config(block_out_channels=decoder_block_out_channels) + self.register_to_config(force_upcast=False) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (EncoderTiny, DecoderTiny)): + module.gradient_checkpointing = value + + def scale_latents(self, x: torch.Tensor) -> torch.Tensor: + """raw latents -> [0, 1]""" + return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1) + + def unscale_latents(self, x: torch.Tensor) -> torch.Tensor: + """[0, 1] -> raw latents""" + return x.sub(self.latent_shift).mul(2 * self.latent_magnitude) + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def enable_tiling(self, use_tiling: bool = True) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: Encoded batch of images. + """ + # scale of encoder output relative to input + sf = self.spatial_scale_factor + tile_size = self.tile_sample_min_size + + # number of pixels to blend and to traverse between tile + blend_size = int(tile_size * self.tile_overlap_factor) + traverse_size = tile_size - blend_size + + # tiles index (up/left) + ti = range(0, x.shape[-2], traverse_size) + tj = range(0, x.shape[-1], traverse_size) + + # mask for blending + blend_masks = torch.stack( + torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij") + ) + blend_masks = blend_masks.clamp(0, 1).to(x.device) + + # output array + out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device) + for i in ti: + for j in tj: + tile_in = x[..., i : i + tile_size, j : j + tile_size] + # tile result + tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf] + tile = self.encoder(tile_in) + h, w = tile.shape[-2], tile.shape[-1] + # blend tile result into output + blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0] + blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1] + blend_mask = blend_mask_i * blend_mask_j + tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w] + tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out) + return out + + def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: Encoded batch of images. + """ + # scale of decoder output relative to input + sf = self.spatial_scale_factor + tile_size = self.tile_latent_min_size + + # number of pixels to blend and to traverse between tiles + blend_size = int(tile_size * self.tile_overlap_factor) + traverse_size = tile_size - blend_size + + # tiles index (up/left) + ti = range(0, x.shape[-2], traverse_size) + tj = range(0, x.shape[-1], traverse_size) + + # mask for blending + blend_masks = torch.stack( + torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij") + ) + blend_masks = blend_masks.clamp(0, 1).to(x.device) + + # output array + out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device) + for i in ti: + for j in tj: + tile_in = x[..., i : i + tile_size, j : j + tile_size] + # tile result + tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf] + tile = self.decoder(tile_in) + h, w = tile.shape[-2], tile.shape[-1] + # blend tile result into output + blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0] + blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1] + blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w] + tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out) + return out + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]: + if self.use_slicing and x.shape[0] > 1: + output = [ + self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1) + ] + output = torch.cat(output) + else: + output = self._tiled_encode(x) if self.use_tiling else self.encoder(x) + + if not return_dict: + return (output,) + + return AutoencoderTinyOutput(latents=output) + + @apply_forward_hook + def decode( + self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True + ) -> Union[DecoderOutput, Tuple[torch.Tensor]]: + if self.use_slicing and x.shape[0] > 1: + output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)] + output = torch.cat(output) + else: + output = self._tiled_decode(x) if self.use_tiling else self.decoder(x) + + if not return_dict: + return (output,) + + return DecoderOutput(sample=output) + + def forward( + self, + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[DecoderOutput, Tuple[torch.Tensor]]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + enc = self.encode(sample).latents + + # scale latents to be in [0, 1], then quantize latents to a byte tensor, + # as if we were storing the latents in an RGBA uint8 image. + scaled_enc = self.scale_latents(enc).mul_(255).round_().byte() + + # unquantize latents back into [0, 1], then unscale latents back to their original range, + # as if we were loading the latents from an RGBA uint8 image. + unscaled_enc = self.unscale_latents(scaled_enc / 255.0) + + dec = self.decode(unscaled_enc) + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) diff --git a/diffusers/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/diffusers/src/diffusers/models/autoencoders/consistency_decoder_vae.py new file mode 100755 index 0000000..a97249f --- /dev/null +++ b/diffusers/src/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -0,0 +1,460 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...schedulers import ConsistencyDecoderScheduler +from ...utils import BaseOutput +from ...utils.accelerate_utils import apply_forward_hook +from ...utils.torch_utils import randn_tensor +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..modeling_utils import ModelMixin +from ..unets.unet_2d import UNet2DModel +from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder + + +@dataclass +class ConsistencyDecoderVAEOutput(BaseOutput): + """ + Output of encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "DiagonalGaussianDistribution" + + +class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): + r""" + The consistency decoder used with DALL-E 3. + + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE + + >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16) + >>> pipe = StableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16 + ... ).to("cuda") + + >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0] + >>> image + ``` + """ + + @register_to_config + def __init__( + self, + scaling_factor: float = 0.18215, + latent_channels: int = 4, + sample_size: int = 32, + encoder_act_fn: str = "silu", + encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + encoder_double_z: bool = True, + encoder_down_block_types: Tuple[str, ...] = ( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + encoder_in_channels: int = 3, + encoder_layers_per_block: int = 2, + encoder_norm_num_groups: int = 32, + encoder_out_channels: int = 4, + decoder_add_attention: bool = False, + decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024), + decoder_down_block_types: Tuple[str, ...] = ( + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ), + decoder_downsample_padding: int = 1, + decoder_in_channels: int = 7, + decoder_layers_per_block: int = 3, + decoder_norm_eps: float = 1e-05, + decoder_norm_num_groups: int = 32, + decoder_num_train_timesteps: int = 1024, + decoder_out_channels: int = 6, + decoder_resnet_time_scale_shift: str = "scale_shift", + decoder_time_embedding_type: str = "learned", + decoder_up_block_types: Tuple[str, ...] = ( + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ), + ): + super().__init__() + self.encoder = Encoder( + act_fn=encoder_act_fn, + block_out_channels=encoder_block_out_channels, + double_z=encoder_double_z, + down_block_types=encoder_down_block_types, + in_channels=encoder_in_channels, + layers_per_block=encoder_layers_per_block, + norm_num_groups=encoder_norm_num_groups, + out_channels=encoder_out_channels, + ) + + self.decoder_unet = UNet2DModel( + add_attention=decoder_add_attention, + block_out_channels=decoder_block_out_channels, + down_block_types=decoder_down_block_types, + downsample_padding=decoder_downsample_padding, + in_channels=decoder_in_channels, + layers_per_block=decoder_layers_per_block, + norm_eps=decoder_norm_eps, + norm_num_groups=decoder_norm_num_groups, + num_train_timesteps=decoder_num_train_timesteps, + out_channels=decoder_out_channels, + resnet_time_scale_shift=decoder_resnet_time_scale_shift, + time_embedding_type=decoder_time_embedding_type, + up_block_types=decoder_up_block_types, + ) + self.decoder_scheduler = ConsistencyDecoderScheduler() + self.register_to_config(block_out_channels=encoder_block_out_channels) + self.register_to_config(force_upcast=False) + self.register_buffer( + "means", + torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None], + persistent=False, + ) + self.register_buffer( + "stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] + instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a + plain `tuple` is returned. + """ + if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.tiled_encode(x, return_dict=return_dict) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return ConsistencyDecoderVAEOutput(latent_dist=posterior) + + @apply_forward_hook + def decode( + self, + z: torch.Tensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + num_inference_steps: int = 2, + ) -> Union[DecoderOutput, Tuple[torch.Tensor]]: + """ + Decodes the input latent vector `z` using the consistency decoder VAE model. + + Args: + z (torch.Tensor): The input latent vector. + generator (Optional[torch.Generator]): The random number generator. Default is None. + return_dict (bool): Whether to return the output as a dictionary. Default is True. + num_inference_steps (int): The number of inference steps. Default is 2. + + Returns: + Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output. + + """ + z = (z * self.config.scaling_factor - self.means) / self.stds + + scale_factor = 2 ** (len(self.config.block_out_channels) - 1) + z = F.interpolate(z, mode="nearest", scale_factor=scale_factor) + + batch_size, _, height, width = z.shape + + self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device) + + x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor( + (batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device + ) + + for t in self.decoder_scheduler.timesteps: + model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1) + model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :] + prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample + x_t = prev_sample + + x_0 = x_t + + if not return_dict: + return (x_0,) + + return DecoderOutput(sample=x_0) + + # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] + instead of a plain tuple. + + Returns: + [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] + is returned, otherwise a plain `tuple` is returned. + """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return ConsistencyDecoderVAEOutput(latent_dist=posterior) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, Tuple[torch.Tensor]]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*, defaults to `None`): + Generator to use for sampling. + + Returns: + [`DecoderOutput`] or `tuple`: + If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, generator=generator).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/diffusers/src/diffusers/models/autoencoders/vae.py b/diffusers/src/diffusers/models/autoencoders/vae.py new file mode 100755 index 0000000..bb80ce8 --- /dev/null +++ b/diffusers/src/diffusers/models/autoencoders/vae.py @@ -0,0 +1,982 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn + +from ...utils import BaseOutput, is_torch_version +from ...utils.torch_utils import randn_tensor +from ..activations import get_activation +from ..attention_processor import SpatialNorm +from ..unets.unet_2d_blocks import ( + AutoencoderTinyBlock, + UNetMidBlock2D, + get_down_block, + get_up_block, +) + + +@dataclass +class DecoderOutput(BaseOutput): + r""" + Output of decoding method. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.Tensor + commit_loss: Optional[torch.FloatTensor] = None + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `Encoder` class.""" + + sample = self.conv_in(sample) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # down + if is_torch_version(">=", "1.11.0"): + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, use_reentrant=False + ) + else: + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + + else: + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + latent_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + latent_embeds, + use_reentrant=False, + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + latent_embeds, + use_reentrant=False, + ) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class UpSample(nn.Module): + r""" + The `UpSample` layer of a variational autoencoder that upsamples its input. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `UpSample` class.""" + x = torch.relu(x) + x = self.deconv(x) + return x + + +class MaskConditionEncoder(nn.Module): + """ + used in AsymmetricAutoencoderKL + """ + + def __init__( + self, + in_ch: int, + out_ch: int = 192, + res_ch: int = 768, + stride: int = 16, + ) -> None: + super().__init__() + + channels = [] + while stride > 1: + stride = stride // 2 + in_ch_ = out_ch * 2 + if out_ch > res_ch: + out_ch = res_ch + if stride == 1: + in_ch_ = res_ch + channels.append((in_ch_, out_ch)) + out_ch *= 2 + + out_channels = [] + for _in_ch, _out_ch in channels: + out_channels.append(_out_ch) + out_channels.append(channels[-1][0]) + + layers = [] + in_ch_ = in_ch + for l in range(len(out_channels)): + out_ch_ = out_channels[l] + if l == 0 or l == 1: + layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1)) + else: + layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1)) + in_ch_ = out_ch_ + + self.layers = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor: + r"""The forward method of the `MaskConditionEncoder` class.""" + out = {} + for l in range(len(self.layers)): + layer = self.layers[l] + x = layer(x) + out[str(tuple(x.shape))] = x + x = torch.relu(x) + return out + + +class MaskConditionDecoder(nn.Module): + r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's + decoder with a conditioner on the mask and masked image. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # condition encoder + self.condition_encoder = MaskConditionEncoder( + in_ch=out_channels, + out_ch=block_out_channels[0], + res_ch=block_out_channels[-1], + ) + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + z: torch.Tensor, + image: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + latent_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r"""The forward method of the `MaskConditionDecoder` class.""" + sample = z + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + latent_embeds, + use_reentrant=False, + ) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.condition_encoder), + masked_image, + mask, + use_reentrant=False, + ) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + latent_embeds, + use_reentrant=False, + ) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds + ) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.condition_encoder), + masked_image, + mask, + ) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = self.condition_encoder(masked_image, mask) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = up_block(sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__( + self, + n_e: int, + vq_embed_dim: int, + beta: float, + remap=None, + unknown_index: str = "random", + sane_index_shape: bool = False, + legacy: bool = True, + ): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.used: torch.Tensor + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]: + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.vq_embed_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q: torch.Tensor = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor: + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q: torch.Tensor = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean + + +class EncoderTiny(nn.Module): + r""" + The `EncoderTiny` layer is a simpler version of the `Encoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`Tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`Tuple[int, ...]`): + The number of output channels for each block. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: Tuple[int, ...], + block_out_channels: Tuple[int, ...], + act_fn: str, + ): + super().__init__() + + layers = [] + for i, num_block in enumerate(num_blocks): + num_channels = block_out_channels[i] + + if i == 0: + layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)) + else: + layers.append( + nn.Conv2d( + num_channels, + num_channels, + kernel_size=3, + padding=1, + stride=2, + bias=False, + ) + ) + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)) + + self.layers = nn.Sequential(*layers) + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `EncoderTiny` class.""" + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) + + else: + # scale image from [-1, 1] to [0, 1] to match TAESD convention + x = self.layers(x.add(1).div(2)) + + return x + + +class DecoderTiny(nn.Module): + r""" + The `DecoderTiny` layer is a simpler version of the `Decoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`Tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`Tuple[int, ...]`): + The number of output channels for each block. + upsampling_scaling_factor (`int`): + The scaling factor to use for upsampling. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: Tuple[int, ...], + block_out_channels: Tuple[int, ...], + upsampling_scaling_factor: int, + act_fn: str, + upsample_fn: str, + ): + super().__init__() + + layers = [ + nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1), + get_activation(act_fn), + ] + + for i, num_block in enumerate(num_blocks): + is_final_block = i == (len(num_blocks) - 1) + num_channels = block_out_channels[i] + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + if not is_final_block: + layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn)) + + conv_out_channel = num_channels if not is_final_block else out_channels + layers.append( + nn.Conv2d( + num_channels, + conv_out_channel, + kernel_size=3, + padding=1, + bias=is_final_block, + ) + ) + + self.layers = nn.Sequential(*layers) + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `DecoderTiny` class.""" + # Clamp. + x = torch.tanh(x / 3) * 3 + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) + + else: + x = self.layers(x) + + # scale image from [0, 1] to [-1, 1] to match diffusers convention + return x.mul(2).sub(1) diff --git a/diffusers/src/diffusers/models/autoencoders/vq_model.py b/diffusers/src/diffusers/models/autoencoders/vq_model.py new file mode 100755 index 0000000..ae8a118 --- /dev/null +++ b/diffusers/src/diffusers/models/autoencoders/vq_model.py @@ -0,0 +1,182 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.accelerate_utils import apply_forward_hook +from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer +from ..modeling_utils import ModelMixin + + +@dataclass +class VQEncoderOutput(BaseOutput): + """ + Output of VQModel encoding method. + + Args: + latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The encoded output sample from the last layer of the model. + """ + + latents: torch.Tensor + + +class VQModel(ModelMixin, ConfigMixin): + r""" + A VQ-VAE model for decoding latent representations. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers. + vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. + scaling_factor (`float`, *optional*, defaults to `0.18215`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + norm_type (`str`, *optional*, defaults to `"group"`): + Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + sample_size: int = 32, + num_vq_embeddings: int = 256, + norm_num_groups: int = 32, + vq_embed_dim: Optional[int] = None, + scaling_factor: float = 0.18215, + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + lookup_from_codebook=False, + force_upcast=False, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=False, + mid_block_add_attention=mid_block_add_attention, + ) + + vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels + + self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) + self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) + self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + norm_type=norm_type, + mid_block_add_attention=mid_block_add_attention, + ) + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput: + h = self.encoder(x) + h = self.quant_conv(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + @apply_forward_hook + def decode( + self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None + ) -> Union[DecoderOutput, torch.Tensor]: + # also go through quantization layer + if not force_not_quantize: + quant, commit_loss, _ = self.quantize(h) + elif self.config.lookup_from_codebook: + quant = self.quantize.get_codebook_entry(h, shape) + commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype) + else: + quant = h + commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype) + quant2 = self.post_quant_conv(quant) + dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None) + + if not return_dict: + return dec, commit_loss + + return DecoderOutput(sample=dec, commit_loss=commit_loss) + + def forward( + self, sample: torch.Tensor, return_dict: bool = True + ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]: + r""" + The [`VQModel`] forward method. + + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.autoencoders.vq_model.VQEncoderOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoders.vq_model.VQEncoderOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoders.vq_model.VQEncoderOutput`] is returned, otherwise a + plain `tuple` is returned. + """ + + h = self.encode(sample).latents + dec = self.decode(h) + + if not return_dict: + return dec.sample, dec.commit_loss + return dec diff --git a/diffusers/src/diffusers/models/branch_cogvideox.py b/diffusers/src/diffusers/models/branch_cogvideox.py new file mode 100755 index 0000000..e241a5f --- /dev/null +++ b/diffusers/src/diffusers/models/branch_cogvideox.py @@ -0,0 +1,435 @@ +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import PeftAdapterMixin +from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ..utils.torch_utils import maybe_allow_in_graph +from .attention import Attention, FeedForward +from .attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 +from .embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from .controlnet import BaseOutput, zero_module +from .modeling_outputs import Transformer2DModelOutput +from .modeling_utils import ModelMixin +from .normalization import AdaLayerNorm, CogVideoXLayerNormZero +from .transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class CogvideoxBranchOutput(BaseOutput): + branch_block_samples: Tuple[torch.Tensor] + + +class CogvideoXBranchModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: int = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, + use_learned_positional_embeddings: bool = False, + wo_text: bool = False, + id_pool_resample_learnable: bool = False, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + if not use_rotary_positional_embeddings and use_learned_positional_embeddings: + raise ValueError( + "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional " + "embeddings. If you're using a custom model and/or believe this should be supported, please open an " + "issue at https://github.com/huggingface/diffusers/issues." + ) + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + in_channels=in_channels * 2 if in_channels == 16 else in_channels, # in Videopainter it was in_channels * 2 + 1 + embed_dim=inner_dim, + text_embed_dim=text_embed_dim, + bias=True, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + ) + self.embedding_dropout = nn.Dropout(dropout) + + # 2. Time embeddings + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + + # 3. Define spatio-temporal transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + wo_text=wo_text, + id_pool_resample_learnable=False, + ) + for _ in range(num_layers) + ] + ) + self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + + # 4. Output blocks + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + chunk_dim=1, + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + + # branch_blocks + self.branch_blocks = nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + self.branch_blocks.append(zero_module(nn.Linear(inner_dim, inner_dim))) + + self.branch_x_embedder = zero_module(torch.nn.Linear(in_channels, inner_dim)) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + @classmethod + def from_transformer( + cls, + transformer, + num_layers: int = 4, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + load_weights_from_transformer=True, + wo_text: bool = False, + ): + config = transformer.config + config["num_layers"] = num_layers + config["attention_head_dim"] = attention_head_dim + config["num_attention_heads"] = num_attention_heads + config["wo_text"] = wo_text + branch = cls(**config) + + if load_weights_from_transformer: + conv_in_condition_weight=torch.zeros_like(branch.patch_embed.proj.weight) + + if config.in_channels == 16: + conv_in_condition_weight[:,:config.in_channels,...]=transformer.patch_embed.proj.weight + conv_in_condition_weight[:,config.in_channels:2*config.in_channels,...]=transformer.patch_embed.proj.weight + elif config.in_channels == 32: + conv_in_condition_weight[:,:config.in_channels//2,...]=transformer.patch_embed.proj.weight[:,:config.in_channels//2,...] + conv_in_condition_weight[:,config.in_channels//2:config.in_channels,...]=transformer.patch_embed.proj.weight[:,:config.in_channels//2,...] + else: + raise ValueError(f"in_channels {config.in_channels} is not supported") + branch.patch_embed.proj.weight.data=torch.nn.Parameter(conv_in_condition_weight) + branch.patch_embed.proj.bias.data=transformer.patch_embed.proj.bias + + branch.embedding_dropout.load_state_dict(transformer.embedding_dropout.state_dict()) + branch.time_proj.load_state_dict(transformer.time_proj.state_dict()) + branch.time_embedding.load_state_dict(transformer.time_embedding.state_dict()) + + branch.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) + + + return branch + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + branch_cond: torch.Tensor = None, + branch_mode: torch.Tensor = None, + conditioning_scale: torch.bfloat16 = 1.0, + timestep: Union[int, float, torch.LongTensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + mask_add: Optional[bool] = False, + wo_text: Optional[bool] = False, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + branch_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + branch_mode (`torch.Tensor`): + The mode tensor of shape `(batch_size, 1)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for branch outputs. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + branch_cond = torch.concat([hidden_states, branch_cond], dim=-3) + hidden_states = self.patch_embed(encoder_hidden_states, branch_cond) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + block_samples = () + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + def create_custom_forward_wo_text(module): + def custom_forward(*inputs): + return module.forward_wo_text(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if not wo_text: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward_wo_text(block), + hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + if not wo_text: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + else: + hidden_states = block.forward_wo_text( + hidden_states=hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + block_samples = block_samples + (hidden_states,) + # branch block + branch_block_samples = () + for block_sample, branch_block in zip(block_samples, self.branch_blocks): + block_sample = branch_block(block_sample) + branch_block_samples = branch_block_samples + (block_sample,) + + branch_block_samples = [sample * conditioning_scale for sample in branch_block_samples] + branch_block_samples = [sample.to(dtype=hidden_states.dtype) for sample in branch_block_samples] + + branch_block_samples = None if len(branch_block_samples) == 0 else branch_block_samples + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (branch_block_samples, ) + + return CogvideoxBranchOutput( + branch_block_samples=branch_block_samples + ) + diff --git a/diffusers/src/diffusers/models/controlnet.py b/diffusers/src/diffusers/models/controlnet.py new file mode 100755 index 0000000..d3ae966 --- /dev/null +++ b/diffusers/src/diffusers/models/controlnet.py @@ -0,0 +1,870 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders.single_file_model import FromOriginalModelMixin +from ..utils import BaseOutput, logging +from .attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) +from .unets.unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + """ + The output of [`ControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the middle block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + """ + A ControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter. + addition_embed_type_num_heads (`int`, defaults to 64): + The number of heads to use for the `TextTimeEmbedding` layer. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 3, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + mid_block_type=unet.config.mid_block_type, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + if hasattr(controlnet, "add_embedding"): + controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: + """ + The [`ControlNetModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if self.config.addition_embed_type is not None: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + # 2. pre-process + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + sample = sample + controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + # 5. Control net blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/diffusers/src/diffusers/models/controlnet_flax.py b/diffusers/src/diffusers/models/controlnet_flax.py new file mode 100755 index 0000000..0540850 --- /dev/null +++ b/diffusers/src/diffusers/models/controlnet_flax.py @@ -0,0 +1,395 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ..configuration_utils import ConfigMixin, flax_register_to_config +from ..utils import BaseOutput +from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from .modeling_flax_utils import FlaxModelMixin +from .unets.unet_2d_blocks_flax import ( + FlaxCrossAttnDownBlock2D, + FlaxDownBlock2D, + FlaxUNetMidBlock2DCrossAttn, +) + + +@flax.struct.dataclass +class FlaxControlNetOutput(BaseOutput): + """ + The output of [`FlaxControlNetModel`]. + + Args: + down_block_res_samples (`jnp.ndarray`): + mid_block_res_sample (`jnp.ndarray`): + """ + + down_block_res_samples: jnp.ndarray + mid_block_res_sample: jnp.ndarray + + +class FlaxControlNetConditioningEmbedding(nn.Module): + conditioning_embedding_channels: int + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256) + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.conv_in = nn.Conv( + self.block_out_channels[0], + kernel_size=(3, 3), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + blocks = [] + for i in range(len(self.block_out_channels) - 1): + channel_in = self.block_out_channels[i] + channel_out = self.block_out_channels[i + 1] + conv1 = nn.Conv( + channel_in, + kernel_size=(3, 3), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + blocks.append(conv1) + conv2 = nn.Conv( + channel_out, + kernel_size=(3, 3), + strides=(2, 2), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + blocks.append(conv2) + self.blocks = blocks + + self.conv_out = nn.Conv( + self.conditioning_embedding_channels, + kernel_size=(3, 3), + padding=((1, 1), (1, 1)), + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + + def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray: + embedding = self.conv_in(conditioning) + embedding = nn.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = nn.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +@flax_register_to_config +class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + A ControlNet model. + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods + implemented for all models (such as downloading or saving). + + This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its + general usage and behavior. + + Inherent JAX features such as the following are supported: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + sample_size (`int`, *optional*): + The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): + The number of channels in the input sample. + down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`): + The tuple of downsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): + The dimension of the attention heads. + num_attention_heads (`int` or `Tuple[int]`, *optional*): + The number of attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): + The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): + Dropout probability for down, up and bottleneck blocks. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + """ + + sample_size: int = 32 + in_channels: int = 4 + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ) + only_cross_attention: Union[bool, Tuple[bool, ...]] = False + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280) + layers_per_block: int = 2 + attention_head_dim: Union[int, Tuple[int, ...]] = 8 + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None + cross_attention_dim: int = 1280 + dropout: float = 0.0 + use_linear_projection: bool = False + dtype: jnp.dtype = jnp.float32 + flip_sin_to_cos: bool = True + freq_shift: int = 0 + controlnet_conditioning_channel_order: str = "rgb" + conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256) + + def init_weights(self, rng: jax.Array) -> FrozenDict: + # init input tensors + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + timesteps = jnp.ones((1,), dtype=jnp.int32) + encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) + controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8) + controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"] + + def setup(self) -> None: + block_out_channels = self.block_out_channels + time_embed_dim = block_out_channels[0] * 4 + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = self.num_attention_heads or self.attention_head_dim + + # input + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # time + self.time_proj = FlaxTimesteps( + block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift + ) + self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + + self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=self.conditioning_embedding_out_channels, + ) + + only_cross_attention = self.only_cross_attention + if isinstance(only_cross_attention, bool): + only_cross_attention = (only_cross_attention,) * len(self.down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(self.down_block_types) + + # down + down_blocks = [] + controlnet_down_blocks = [] + + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv( + output_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(self.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlock2D": + down_block = FlaxCrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + num_attention_heads=num_attention_heads[i], + add_downsample=not is_final_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], + dtype=self.dtype, + ) + else: + down_block = FlaxDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + + down_blocks.append(down_block) + + for _ in range(self.layers_per_block): + controlnet_block = nn.Conv( + output_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv( + output_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + controlnet_down_blocks.append(controlnet_block) + + self.down_blocks = down_blocks + self.controlnet_down_blocks = controlnet_down_blocks + + # mid + mid_block_channel = block_out_channels[-1] + self.mid_block = FlaxUNetMidBlock2DCrossAttn( + in_channels=mid_block_channel, + dropout=self.dropout, + num_attention_heads=num_attention_heads[-1], + use_linear_projection=self.use_linear_projection, + dtype=self.dtype, + ) + + self.controlnet_mid_block = nn.Conv( + mid_block_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + + def __call__( + self, + sample: jnp.ndarray, + timesteps: Union[jnp.ndarray, float, int], + encoder_hidden_states: jnp.ndarray, + controlnet_cond: jnp.ndarray, + conditioning_scale: float = 1.0, + return_dict: bool = True, + train: bool = False, + ) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]: + r""" + Args: + sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor + timestep (`jnp.ndarray` or `float` or `int`): timesteps + encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states + controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor + conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of + a plain tuple. + train (`bool`, *optional*, defaults to `False`): + Use deterministic functions and disable dropout when not training. + + Returns: + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise + a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + channel_order = self.controlnet_conditioning_channel_order + if channel_order == "bgr": + controlnet_cond = jnp.flip(controlnet_cond, axis=1) + + # 1. time + if not isinstance(timesteps, jnp.ndarray): + timesteps = jnp.array([timesteps], dtype=jnp.int32) + elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: + timesteps = timesteps.astype(dtype=jnp.float32) + timesteps = jnp.expand_dims(timesteps, 0) + + t_emb = self.time_proj(timesteps) + t_emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = jnp.transpose(sample, (0, 2, 3, 1)) + sample = self.conv_in(sample) + + controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1)) + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + sample += controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + if isinstance(down_block, FlaxCrossAttnDownBlock2D): + sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + else: + sample, res_samples = down_block(sample, t_emb, deterministic=not train) + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + + # 5. contronet blocks + controlnet_down_block_res_samples = () + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return FlaxControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) diff --git a/diffusers/src/diffusers/models/controlnet_flux.py b/diffusers/src/diffusers/models/controlnet_flux.py new file mode 100755 index 0000000..036e565 --- /dev/null +++ b/diffusers/src/diffusers/models/controlnet_flux.py @@ -0,0 +1,517 @@ +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import PeftAdapterMixin +from ..models.attention_processor import AttentionProcessor +from ..models.modeling_utils import ModelMixin +from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from .controlnet import BaseOutput, zero_module +from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from .modeling_outputs import Transformer2DModelOutput +from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FluxControlNetOutput(BaseOutput): + controlnet_block_samples: Tuple[torch.Tensor] + controlnet_single_block_samples: Tuple[torch.Tensor] + + +class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: List[int] = [16, 56, 56], + num_mode: int = None, + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_single_layers) + ] + ) + + # controlnet_blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) + + self.controlnet_single_blocks = nn.ModuleList([]) + for _ in range(len(self.single_transformer_blocks)): + self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) + + self.union = num_mode is not None + if self.union: + self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) + + self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self): + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @classmethod + def from_transformer( + cls, + transformer, + num_layers: int = 4, + num_single_layers: int = 10, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + load_weights_from_transformer=True, + ): + config = transformer.config + config["num_layers"] = num_layers + config["num_single_layers"] = num_single_layers + config["attention_head_dim"] = attention_head_dim + config["num_attention_heads"] = num_attention_heads + + controlnet = cls(**config) + + if load_weights_from_transformer: + controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) + controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) + controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) + controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) + controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) + controlnet.single_transformer_blocks.load_state_dict( + transformer.single_transformer_blocks.state_dict(), strict=False + ) + + controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) + + return controlnet + + def forward( + self, + hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + controlnet_mode: torch.Tensor = None, + conditioning_scale: float = 1.0, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + controlnet_mode (`torch.Tensor`): + The mode tensor of shape `(batch_size, 1)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + # add + hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if self.union: + # union mode + if controlnet_mode is None: + raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") + # union mode emb + controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) + encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) + txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + block_samples = () + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + block_samples = block_samples + (hidden_states,) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + single_block_samples = () + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) + + # controlnet block + controlnet_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + + controlnet_single_block_samples = () + for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): + single_block_sample = controlnet_block(single_block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) + + # scaling + controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] + controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] + + controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples + controlnet_single_block_samples = ( + None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples + ) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (controlnet_block_samples, controlnet_single_block_samples) + + return FluxControlNetOutput( + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + ) + + +class FluxMultiControlNetModel(ModelMixin): + r""" + `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel + + This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be + compatible with `FluxControlNetModel`. + + Args: + controlnets (`List[FluxControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `FluxControlNetModel` as a list. + """ + + def __init__(self, controlnets): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + controlnet_cond: List[torch.tensor], + controlnet_mode: List[torch.tensor], + conditioning_scale: List[float], + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[FluxControlNetOutput, Tuple]: + # ControlNet-Union with multiple conditions + # only load one ControlNet for saving memories + if len(self.nets) == 1 and self.nets[0].union: + controlnet = self.nets[0] + + for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)): + block_samples, single_block_samples = controlnet( + hidden_states=hidden_states, + controlnet_cond=image, + controlnet_mode=mode[:, None], + conditioning_scale=scale, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_projections, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + img_ids=img_ids, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + control_single_block_samples = single_block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] + + # Regular Multi-ControlNets + # load all ControlNets into memories + else: + for i, (image, mode, scale, controlnet) in enumerate( + zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets) + ): + block_samples, single_block_samples = controlnet( + hidden_states=hidden_states, + controlnet_cond=image, + controlnet_mode=mode[:, None], + conditioning_scale=scale, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_projections, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + img_ids=img_ids, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + control_single_block_samples = single_block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] + + return control_block_samples, control_single_block_samples diff --git a/diffusers/src/diffusers/models/controlnet_hunyuan.py b/diffusers/src/diffusers/models/controlnet_hunyuan.py new file mode 100755 index 0000000..4277d81 --- /dev/null +++ b/diffusers/src/diffusers/models/controlnet_hunyuan.py @@ -0,0 +1,401 @@ +# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Dict, Optional, Union + +import torch +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging +from .attention_processor import AttentionProcessor +from .controlnet import BaseOutput, Tuple, zero_module +from .embeddings import ( + HunyuanCombinedTimestepTextSizeStyleEmbedding, + PatchEmbed, + PixArtAlphaTextProjection, +) +from .modeling_utils import ModelMixin +from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class HunyuanControlNetOutput(BaseOutput): + controlnet_block_samples: Tuple[torch.Tensor] + + +class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + conditioning_channels: int = 3, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "gelu-approximate", + sample_size=32, + hidden_size=1152, + transformer_num_layers: int = 40, + mlp_ratio: float = 4.0, + cross_attention_dim: int = 1024, + cross_attention_dim_t5: int = 2048, + pooled_projection_dim: int = 1024, + text_len: int = 77, + text_len_t5: int = 256, + use_style_cond_and_image_meta_size: bool = True, + ): + super().__init__() + self.num_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + + self.text_embedder = PixArtAlphaTextProjection( + in_features=cross_attention_dim_t5, + hidden_size=cross_attention_dim_t5 * 4, + out_features=cross_attention_dim, + act_fn="silu_fp32", + ) + + self.text_embedding_padding = nn.Parameter( + torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32) + ) + + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + in_channels=in_channels, + embed_dim=hidden_size, + patch_size=patch_size, + pos_embed_type=None, + ) + + self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding( + hidden_size, + pooled_projection_dim=pooled_projection_dim, + seq_len=text_len_t5, + cross_attention_dim=cross_attention_dim_t5, + use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size, + ) + + # controlnet_blocks + self.controlnet_blocks = nn.ModuleList([]) + + # HunyuanDiT Blocks + self.blocks = nn.ModuleList( + [ + HunyuanDiTBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + activation_fn=activation_fn, + ff_inner_dim=int(self.inner_dim * mlp_ratio), + cross_attention_dim=cross_attention_dim, + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + skip=False, # always False as it is the first half of the model + ) + for layer in range(transformer_num_layers // 2 - 1) + ] + ) + self.input_block = zero_module(nn.Linear(hidden_size, hidden_size)) + for _ in range(len(self.blocks)): + controlnet_block = nn.Linear(hidden_size, hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the + corresponding cross attention processor. This is strongly recommended when setting trainable attention + processors. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + @classmethod + def from_transformer( + cls, transformer, conditioning_channels=3, transformer_num_layers=None, load_weights_from_transformer=True + ): + config = transformer.config + activation_fn = config.activation_fn + attention_head_dim = config.attention_head_dim + cross_attention_dim = config.cross_attention_dim + cross_attention_dim_t5 = config.cross_attention_dim_t5 + hidden_size = config.hidden_size + in_channels = config.in_channels + mlp_ratio = config.mlp_ratio + num_attention_heads = config.num_attention_heads + patch_size = config.patch_size + sample_size = config.sample_size + text_len = config.text_len + text_len_t5 = config.text_len_t5 + + conditioning_channels = conditioning_channels + transformer_num_layers = transformer_num_layers or config.transformer_num_layers + + controlnet = cls( + conditioning_channels=conditioning_channels, + transformer_num_layers=transformer_num_layers, + activation_fn=activation_fn, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + cross_attention_dim_t5=cross_attention_dim_t5, + hidden_size=hidden_size, + in_channels=in_channels, + mlp_ratio=mlp_ratio, + num_attention_heads=num_attention_heads, + patch_size=patch_size, + sample_size=sample_size, + text_len=text_len, + text_len_t5=text_len_t5, + ) + if load_weights_from_transformer: + key = controlnet.load_state_dict(transformer.state_dict(), strict=False) + logger.warning(f"controlnet load from Hunyuan-DiT. missing_keys: {key[0]}") + return controlnet + + def forward( + self, + hidden_states, + timestep, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + image_rotary_emb=None, + return_dict=True, + ): + """ + The [`HunyuanDiT2DControlNetModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`): + The input tensor. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. + controlnet_cond ( `torch.Tensor` ): + The conditioning input to ControlNet. + conditioning_scale ( `float` ): + Indicate the conditioning scale. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of `BertModel`. + text_embedding_mask: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of `BertModel`. + encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder. + text_embedding_mask_t5: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of T5 Text Encoder. + image_meta_size (torch.Tensor): + Conditional embedding indicate the image sizes + style: torch.Tensor: + Conditional embedding indicate the style + image_rotary_emb (`torch.Tensor`): + The image rotary embeddings to apply on query and key tensors during attention calculation. + return_dict: bool + Whether to return a dictionary. + """ + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) # b,c,H,W -> b, N, C + + # 2. pre-process + hidden_states = hidden_states + self.input_block(self.pos_embed(controlnet_cond)) + + temb = self.time_extra_emb( + timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype + ) # [B, D] + + # text projection + batch_size, sequence_length, _ = encoder_hidden_states_t5.shape + encoder_hidden_states_t5 = self.text_embedder( + encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1]) + ) + encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1) + + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1) + text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1) + text_embedding_mask = text_embedding_mask.unsqueeze(2).bool() + + encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding) + + block_res_samples = () + for layer, block in enumerate(self.blocks): + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) # (N, L, D) + + block_res_samples = block_res_samples + (hidden_states,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + # 6. scaling + controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples] + + if not return_dict: + return (controlnet_block_res_samples,) + + return HunyuanControlNetOutput(controlnet_block_samples=controlnet_block_res_samples) + + +class HunyuanDiT2DMultiControlNetModel(ModelMixin): + r""" + `HunyuanDiT2DMultiControlNetModel` wrapper class for Multi-HunyuanDiT2DControlNetModel + + This module is a wrapper for multiple instances of the `HunyuanDiT2DControlNetModel`. The `forward()` API is + designed to be compatible with `HunyuanDiT2DControlNetModel`. + + Args: + controlnets (`List[HunyuanDiT2DControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `HunyuanDiT2DControlNetModel` as a list. + """ + + def __init__(self, controlnets): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + hidden_states, + timestep, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + image_rotary_emb=None, + return_dict=True, + ): + """ + The [`HunyuanDiT2DControlNetModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`): + The input tensor. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. + controlnet_cond ( `torch.Tensor` ): + The conditioning input to ControlNet. + conditioning_scale ( `float` ): + Indicate the conditioning scale. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of `BertModel`. + text_embedding_mask: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of `BertModel`. + encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder. + text_embedding_mask_t5: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of T5 Text Encoder. + image_meta_size (torch.Tensor): + Conditional embedding indicate the image sizes + style: torch.Tensor: + Conditional embedding indicate the style + image_rotary_emb (`torch.Tensor`): + The image rotary embeddings to apply on query and key tensors during attention calculation. + return_dict: bool + Whether to return a dictionary. + """ + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + block_samples = controlnet( + hidden_states=hidden_states, + timestep=timestep, + controlnet_cond=image, + conditioning_scale=scale, + encoder_hidden_states=encoder_hidden_states, + text_embedding_mask=text_embedding_mask, + encoder_hidden_states_t5=encoder_hidden_states_t5, + text_embedding_mask_t5=text_embedding_mask_t5, + image_meta_size=image_meta_size, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0]) + ] + control_block_samples = (control_block_samples,) + + return control_block_samples diff --git a/diffusers/src/diffusers/models/controlnet_sd3.py b/diffusers/src/diffusers/models/controlnet_sd3.py new file mode 100755 index 0000000..f19571d --- /dev/null +++ b/diffusers/src/diffusers/models/controlnet_sd3.py @@ -0,0 +1,422 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import FromOriginalModelMixin, PeftAdapterMixin +from ..models.attention import JointTransformerBlock +from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 +from ..models.modeling_outputs import Transformer2DModelOutput +from ..models.modeling_utils import ModelMixin +from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from .controlnet import BaseOutput, zero_module +from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class SD3ControlNetOutput(BaseOutput): + controlnet_block_samples: Tuple[torch.Tensor] + + +class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 18, + attention_head_dim: int = 64, + num_attention_heads: int = 18, + joint_attention_dim: int = 4096, + caption_projection_dim: int = 1152, + pooled_projection_dim: int = 2048, + out_channels: int = 16, + pos_embed_max_size: int = 96, + extra_conditioning_channels: int = 0, + ): + super().__init__() + default_out_channels = in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + pos_embed_max_size=pos_embed_max_size, + ) + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim) + + # `attention_head_dim` is doubled to account for the mixing. + # It needs to crafted when we get the actual checkpoints. + self.transformer_blocks = nn.ModuleList( + [ + JointTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + context_pre_only=False, + ) + for i in range(num_layers) + ] + ) + + # controlnet_blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + controlnet_block = nn.Linear(self.inner_dim, self.inner_dim) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + pos_embed_input = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels + extra_conditioning_channels, + embed_dim=self.inner_dim, + pos_embed_type=None, + ) + self.pos_embed_input = zero_module(pos_embed_input) + + self.gradient_checkpointing = False + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedJointAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @classmethod + def from_transformer( + cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True + ): + config = transformer.config + config["num_layers"] = num_layers or config.num_layers + config["extra_conditioning_channels"] = num_extra_conditioning_channels + controlnet = cls(**config) + + if load_weights_from_transformer: + controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) + controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) + controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) + controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) + + controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input) + + return controlnet + + def forward( + self, + hidden_states: torch.FloatTensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`SD3Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # add + hidden_states = hidden_states + self.pos_embed_input(controlnet_cond) + + block_res_samples = () + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + + block_res_samples = block_res_samples + (hidden_states,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + # 6. scaling + controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples] + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (controlnet_block_res_samples,) + + return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples) + + +class SD3MultiControlNetModel(ModelMixin): + r""" + `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet + + This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be + compatible with `SD3ControlNetModel`. + + Args: + controlnets (`List[SD3ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `SD3ControlNetModel` as a list. + """ + + def __init__(self, controlnets): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + pooled_projections: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[SD3ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + block_samples = controlnet( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + pooled_projections=pooled_projections, + controlnet_cond=image, + conditioning_scale=scale, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0]) + ] + control_block_samples = (tuple(control_block_samples),) + + return control_block_samples diff --git a/diffusers/src/diffusers/models/controlnet_sparsectrl.py b/diffusers/src/diffusers/models/controlnet_sparsectrl.py new file mode 100755 index 0000000..fa37e1f --- /dev/null +++ b/diffusers/src/diffusers/models/controlnet_sparsectrl.py @@ -0,0 +1,788 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import FromOriginalModelMixin +from ..utils import BaseOutput, logging +from .attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn +from .unets.unet_2d_condition import UNet2DConditionModel +from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class SparseControlNetOutput(BaseOutput): + """ + The output of [`SparseControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the middle block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class SparseControlNetConditioningEmbedding(nn.Module): + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning: torch.Tensor) -> torch.Tensor: + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + return embedding + + +class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + """ + A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion + Models](https://arxiv.org/abs/2311.16933). + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + conditioning_channels (`int`, defaults to 4): + The number of input channels in the controlnet conditional embedding module. If + `concat_condition_embedding` is True, the value provided here is incremented by 1. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer layers to use in each layer in the middle block. + attention_head_dim (`int` or `Tuple[int]`, defaults to 8): + The dimension of the attention heads. + num_attention_heads (`int` or `Tuple[int]`, *optional*): + The number of heads to use for multi-head attention. + use_linear_projection (`bool`, defaults to `False`): + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter + controlnet_conditioning_channel_order (`str`, defaults to `rgb`): + motion_max_seq_length (`int`, defaults to `32`): + The maximum sequence length to use in the motion module. + motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`): + The number of heads to use in each attention layer of the motion module. + concat_conditioning_mask (`bool`, defaults to `True`): + use_simplified_condition_embedding (`bool`, defaults to `True`): + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockMotion", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 768, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None, + temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + controlnet_conditioning_channel_order: str = "rgb", + motion_max_seq_length: int = 32, + motion_num_attention_heads: int = 8, + concat_conditioning_mask: bool = True, + use_simplified_condition_embedding: bool = True, + ): + super().__init__() + self.use_simplified_condition_embedding = use_simplified_condition_embedding + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + if concat_conditioning_mask: + conditioning_channels = conditioning_channels + 1 + + self.concat_conditioning_mask = concat_conditioning_mask + + # control net conditioning embedding + if use_simplified_condition_embedding: + self.controlnet_cond_embedding = zero_module( + nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + ) + else: + self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(motion_num_attention_heads, int): + motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlockMotion": + down_block = CrossAttnDownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + dropout=0, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + add_downsample=not is_final_block, + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + temporal_double_self_attention=False, + ) + elif down_block_type == "DownBlockMotion": + down_block = DownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + dropout=0, + num_layers=layers_per_block, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + add_downsample=not is_final_block, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + temporal_double_self_attention=False, + ) + else: + raise ValueError( + "Invalid `block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`" + ) + + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channels = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channels, mid_block_channels, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + if transformer_layers_per_mid_block is None: + transformer_layers_per_mid_block = ( + transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1 + ) + + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=mid_block_channels, + temb_channels=time_embed_dim, + dropout=0, + num_layers=1, + transformer_layers_per_block=transformer_layers_per_mid_block, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + num_attention_heads=num_attention_heads[-1], + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type="default", + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 3, + ) -> "SparseControlNetModel": + r""" + Instantiate a [`SparseControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also + copied where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + down_block_types = unet.config.down_block_types + + for i in range(len(down_block_types)): + if "CrossAttn" in down_block_types[i]: + down_block_types[i] = "CrossAttnDownBlockMotion" + elif "Down" in down_block_types[i]: + down_block_types[i] = "DownBlockMotion" + else: + raise ValueError("Invalid `block_type` encountered. Must be a cross-attention or down block") + + controlnet = cls( + in_channels=unet.config.in_channels, + conditioning_channels=conditioning_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False) + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False) + + return controlnet + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + conditioning_mask: Optional[torch.Tensor] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: + """ + The [`SparseControlNetModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape + sample = torch.zeros_like(sample) + + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(sample_num_frames, dim=0) + + # 2. pre-process + batch_size, channels, num_frames, height, width = sample.shape + + sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + sample = self.conv_in(sample) + + batch_frames, channels, height, width = sample.shape + sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width) + + if self.concat_conditioning_mask: + controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) + + batch_size, channels, num_frames, height, width = controlnet_cond.shape + controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height, width + ) + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + batch_frames, channels, height, width = controlnet_cond.shape + controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width) + + sample = sample + controlnet_cond + + batch_size, num_frames, channels, height, width = sample.shape + sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + # 5. Control net blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return SparseControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +# Copied from diffusers.models.controlnet.zero_module +def zero_module(module: nn.Module) -> nn.Module: + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/diffusers/src/diffusers/models/controlnet_xs.py b/diffusers/src/diffusers/models/controlnet_xs.py new file mode 100755 index 0000000..f676a70 --- /dev/null +++ b/diffusers/src/diffusers/models/controlnet_xs.py @@ -0,0 +1,1945 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from math import gcd +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, is_torch_version, logging +from ..utils.torch_utils import apply_freeu +from .attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from .controlnet import ControlNetConditioningEmbedding +from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + Downsample2D, + ResnetBlock2D, + Transformer2DModel, + UNetMidBlock2DCrossAttn, + Upsample2D, +) +from .unets.unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetXSOutput(BaseOutput): + """ + The output of [`UNetControlNetXSModel`]. + + Args: + sample (`Tensor` of shape `(batch_size, num_channels, height, width)`): + The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base + model output, but is already the final output. + """ + + sample: Tensor = None + + +class DownBlockControlNetXSAdapter(nn.Module): + """Components that together with corresponding components from the base model will form a + `ControlNetXSCrossAttnDownBlock2D`""" + + def __init__( + self, + resnets: nn.ModuleList, + base_to_ctrl: nn.ModuleList, + ctrl_to_base: nn.ModuleList, + attentions: Optional[nn.ModuleList] = None, + downsampler: Optional[nn.Conv2d] = None, + ): + super().__init__() + self.resnets = resnets + self.base_to_ctrl = base_to_ctrl + self.ctrl_to_base = ctrl_to_base + self.attentions = attentions + self.downsamplers = downsampler + + +class MidBlockControlNetXSAdapter(nn.Module): + """Components that together with corresponding components from the base model will form a + `ControlNetXSCrossAttnMidBlock2D`""" + + def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleList, ctrl_to_base: nn.ModuleList): + super().__init__() + self.midblock = midblock + self.base_to_ctrl = base_to_ctrl + self.ctrl_to_base = ctrl_to_base + + +class UpBlockControlNetXSAdapter(nn.Module): + """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnUpBlock2D`""" + + def __init__(self, ctrl_to_base: nn.ModuleList): + super().__init__() + self.ctrl_to_base = ctrl_to_base + + +def get_down_block_adapter( + base_in_channels: int, + base_out_channels: int, + ctrl_in_channels: int, + ctrl_out_channels: int, + temb_channels: int, + max_norm_num_groups: Optional[int] = 32, + has_crossattn=True, + transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + add_downsample: bool = True, + upcast_attention: Optional[bool] = False, + use_linear_projection: Optional[bool] = True, +): + num_layers = 2 # only support sd + sdxl + + resnets = [] + attentions = [] + ctrl_to_base = [] + base_to_ctrl = [] + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + base_in_channels = base_in_channels if i == 0 else base_out_channels + ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels + + # Before the resnet/attention application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) + + resnets.append( + ResnetBlock2D( + in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl + out_channels=ctrl_out_channels, + temb_channels=temb_channels, + groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), + groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + eps=1e-5, + ) + ) + + if has_crossattn: + attentions.append( + Transformer2DModel( + num_attention_heads, + ctrl_out_channels // num_attention_heads, + in_channels=ctrl_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + ) + ) + + # After the resnet/attention application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + + if add_downsample: + # Before the downsampler application, information is concatted from base to control + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) + + downsamplers = Downsample2D( + ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" + ) + + # After the downsampler application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + else: + downsamplers = None + + down_block_components = DownBlockControlNetXSAdapter( + resnets=nn.ModuleList(resnets), + base_to_ctrl=nn.ModuleList(base_to_ctrl), + ctrl_to_base=nn.ModuleList(ctrl_to_base), + ) + + if has_crossattn: + down_block_components.attentions = nn.ModuleList(attentions) + if downsamplers is not None: + down_block_components.downsamplers = downsamplers + + return down_block_components + + +def get_mid_block_adapter( + base_channels: int, + ctrl_channels: int, + temb_channels: Optional[int] = None, + max_norm_num_groups: Optional[int] = 32, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + upcast_attention: bool = False, + use_linear_projection: bool = True, +): + # Before the midblock application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl = make_zero_conv(base_channels, base_channels) + + midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=ctrl_channels + base_channels, + out_channels=ctrl_channels, + temb_channels=temb_channels, + # number or norm groups must divide both in_channels and out_channels + resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups), + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + # After the midblock application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) + + return MidBlockControlNetXSAdapter(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base) + + +def get_up_block_adapter( + out_channels: int, + prev_output_channel: int, + ctrl_skip_channels: List[int], +): + ctrl_to_base = [] + num_layers = 3 # only support sd + sdxl + for i in range(num_layers): + resnet_in_channels = prev_output_channel if i == 0 else out_channels + ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) + + return UpBlockControlNetXSAdapter(ctrl_to_base=nn.ModuleList(ctrl_to_base)) + + +class ControlNetXSAdapter(ModelMixin, ConfigMixin): + r""" + A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a + `UNet2DConditionModel` base model). + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic + methods implemented for all models (such as downloading or saving). + + Like `UNetControlNetXSModel`, `ControlNetXSAdapter` is compatible with StableDiffusion and StableDiffusion-XL. It's + default parameters are compatible with StableDiffusion. + + Parameters: + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channels for each block in the `controlnet_cond_embedding` layer. + time_embedding_mix (`float`, defaults to 1.0): + If 0, then only the control adapters's time embedding is used. If 1, then only the base unet's time + embedding is used. Otherwise, both are combined. + learn_time_embedding (`bool`, defaults to `False`): + Whether a time embedding should be learned. If yes, `UNetControlNetXSModel` will combine the time + embeddings of the base model and the control adapter. If no, `UNetControlNetXSModel` will use the base + model's time embedding. + num_attention_heads (`list[int]`, defaults to `[4]`): + The number of attention heads. + block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`): + The tuple of output channels for each block. + base_block_out_channels (`list[int]`, defaults to `[320, 640, 1280, 1280]`): + The tuple of output channels for each block in the base unet. + cross_attention_dim (`int`, defaults to 1024): + The dimension of the cross attention features. + down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`): + The tuple of downsample blocks to use. + sample_size (`int`, defaults to 96): + Height and width of input/output sample. + transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + upcast_attention (`bool`, defaults to `True`): + Whether the attention computation should always be upcasted. + max_norm_num_groups (`int`, defaults to 32): + Maximum number of groups in group normal. The actual number will be the largest divisor of the respective + channels, that is <= max_norm_num_groups. + """ + + @register_to_config + def __init__( + self, + conditioning_channels: int = 3, + conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + time_embedding_mix: float = 1.0, + learn_time_embedding: bool = False, + num_attention_heads: Union[int, Tuple[int]] = 4, + block_out_channels: Tuple[int] = (4, 8, 16, 16), + base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + cross_attention_dim: int = 1024, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + sample_size: Optional[int] = 96, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + upcast_attention: bool = True, + max_norm_num_groups: int = 32, + use_linear_projection: bool = True, + ): + super().__init__() + + time_embedding_input_dim = base_block_out_channels[0] + time_embedding_dim = base_block_out_channels[0] * 4 + + # Check inputs + if conditioning_channel_order not in ["rgb", "bgr"]: + raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}") + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(transformer_layers_per_block, (list, tuple)): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + if not isinstance(cross_attention_dim, (list, tuple)): + cross_attention_dim = [cross_attention_dim] * len(down_block_types) + # see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAdapter` takes `num_attention_heads` instead of `attention_head_dim` + if not isinstance(num_attention_heads, (list, tuple)): + num_attention_heads = [num_attention_heads] * len(down_block_types) + + if len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + # 5 - Create conditioning hint embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + # time + if learn_time_embedding: + self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim) + else: + self.time_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.up_connections = nn.ModuleList([]) + + # input + self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) + self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0]) + + # down + base_out_channels = base_block_out_channels[0] + ctrl_out_channels = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + base_in_channels = base_out_channels + base_out_channels = base_block_out_channels[i] + ctrl_in_channels = ctrl_out_channels + ctrl_out_channels = block_out_channels[i] + has_crossattn = "CrossAttn" in down_block_type + is_final_block = i == len(down_block_types) - 1 + + self.down_blocks.append( + get_down_block_adapter( + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=time_embedding_dim, + max_norm_num_groups=max_norm_num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block[i], + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + add_downsample=not is_final_block, + upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, + ) + ) + + # mid + self.mid_block = get_mid_block_adapter( + base_channels=base_block_out_channels[-1], + ctrl_channels=block_out_channels[-1], + temb_channels=time_embedding_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, + ) + + # up + # The skip connection channels are the output of the conv_in and of all the down subblocks + ctrl_skip_channels = [block_out_channels[0]] + for i, out_channels in enumerate(block_out_channels): + number_of_subblocks = ( + 3 if i < len(block_out_channels) - 1 else 2 + ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler + ctrl_skip_channels.extend([out_channels] * number_of_subblocks) + + reversed_base_block_out_channels = list(reversed(base_block_out_channels)) + + base_out_channels = reversed_base_block_out_channels[0] + for i in range(len(down_block_types)): + prev_base_output_channel = base_out_channels + base_out_channels = reversed_base_block_out_channels[i] + ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] + + self.up_connections.append( + get_up_block_adapter( + out_channels=base_out_channels, + prev_output_channel=prev_base_output_channel, + ctrl_skip_channels=ctrl_skip_channels_, + ) + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + size_ratio: Optional[float] = None, + block_out_channels: Optional[List[int]] = None, + num_attention_heads: Optional[List[int]] = None, + learn_time_embedding: bool = False, + time_embedding_mix: int = 1.0, + conditioning_channels: int = 3, + conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + r""" + Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model we want to control. The dimensions of the ControlNetXSAdapter will be adapted to it. + size_ratio (float, *optional*, defaults to `None`): + When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this + or `block_out_channels` must be given. + block_out_channels (`List[int]`, *optional*, defaults to `None`): + Down blocks output channels in control model. Either this or `size_ratio` must be given. + num_attention_heads (`List[int]`, *optional*, defaults to `None`): + The dimension of the attention heads. The naming seems a bit confusing and it is, see + https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + learn_time_embedding (`bool`, defaults to `False`): + Whether the `ControlNetXSAdapter` should learn a time embedding. + time_embedding_mix (`float`, defaults to 1.0): + If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time + embedding is used. Otherwise, both are combined. + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. + """ + + # Check input + fixed_size = block_out_channels is not None + relative_size = size_ratio is not None + if not (fixed_size ^ relative_size): + raise ValueError( + "Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)." + ) + + # Create model + block_out_channels = block_out_channels or [int(b * size_ratio) for b in unet.config.block_out_channels] + if num_attention_heads is None: + # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + num_attention_heads = unet.config.attention_head_dim + + model = cls( + conditioning_channels=conditioning_channels, + conditioning_channel_order=conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + time_embedding_mix=time_embedding_mix, + learn_time_embedding=learn_time_embedding, + num_attention_heads=num_attention_heads, + block_out_channels=block_out_channels, + base_block_out_channels=unet.config.block_out_channels, + cross_attention_dim=unet.config.cross_attention_dim, + down_block_types=unet.config.down_block_types, + sample_size=unet.config.sample_size, + transformer_layers_per_block=unet.config.transformer_layers_per_block, + upcast_attention=unet.config.upcast_attention, + max_norm_num_groups=unet.config.norm_num_groups, + use_linear_projection=unet.config.use_linear_projection, + ) + + # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + + return model + + def forward(self, *args, **kwargs): + raise ValueError( + "A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel." + ) + + +class UNetControlNetXSModel(ModelMixin, ConfigMixin): + r""" + A UNet fused with a ControlNet-XS adapter model + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic + methods implemented for all models (such as downloading or saving). + + `UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. It's default parameters are + compatible with StableDiffusion. + + It's parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in + `ControlNetXSAdapter` . See their documentation for details. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + # unet configs + sample_size: Optional[int] = 96, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + norm_num_groups: Optional[int] = 32, + cross_attention_dim: Union[int, Tuple[int]] = 1024, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: Union[int, Tuple[int]] = 8, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + upcast_attention: bool = True, + use_linear_projection: bool = True, + time_cond_proj_dim: Optional[int] = None, + projection_class_embeddings_input_dim: Optional[int] = None, + # additional controlnet configs + time_embedding_mix: float = 1.0, + ctrl_conditioning_channels: int = 3, + ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + ctrl_conditioning_channel_order: str = "rgb", + ctrl_learn_time_embedding: bool = False, + ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16), + ctrl_num_attention_heads: Union[int, Tuple[int]] = 4, + ctrl_max_norm_num_groups: int = 32, + ): + super().__init__() + + if time_embedding_mix < 0 or time_embedding_mix > 1: + raise ValueError("`time_embedding_mix` needs to be between 0 and 1.") + if time_embedding_mix < 1 and not ctrl_learn_time_embedding: + raise ValueError("To use `time_embedding_mix` < 1, `ctrl_learn_time_embedding` must be `True`") + + if addition_embed_type is not None and addition_embed_type != "text_time": + raise ValueError( + "As `UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`." + ) + + if not isinstance(transformer_layers_per_block, (list, tuple)): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + if not isinstance(cross_attention_dim, (list, tuple)): + cross_attention_dim = [cross_attention_dim] * len(down_block_types) + if not isinstance(num_attention_heads, (list, tuple)): + num_attention_heads = [num_attention_heads] * len(down_block_types) + if not isinstance(ctrl_num_attention_heads, (list, tuple)): + ctrl_num_attention_heads = [ctrl_num_attention_heads] * len(down_block_types) + + base_num_attention_heads = num_attention_heads + + self.in_channels = 4 + + # # Input + self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=ctrl_block_out_channels[0], + block_out_channels=ctrl_conditioning_embedding_out_channels, + conditioning_channels=ctrl_conditioning_channels, + ) + self.ctrl_conv_in = nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1) + self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0]) + + # # Time + time_embed_input_dim = block_out_channels[0] + time_embed_dim = block_out_channels[0] * 4 + + self.base_time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) + self.base_time_embedding = TimestepEmbedding( + time_embed_input_dim, + time_embed_dim, + cond_proj_dim=time_cond_proj_dim, + ) + if ctrl_learn_time_embedding: + self.ctrl_time_embedding = TimestepEmbedding( + in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim + ) + else: + self.ctrl_time_embedding = None + + if addition_embed_type is None: + self.base_add_time_proj = None + self.base_add_embedding = None + else: + self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.base_add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + # # Create down blocks + down_blocks = [] + base_out_channels = block_out_channels[0] + ctrl_out_channels = ctrl_block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + base_in_channels = base_out_channels + base_out_channels = block_out_channels[i] + ctrl_in_channels = ctrl_out_channels + ctrl_out_channels = ctrl_block_out_channels[i] + has_crossattn = "CrossAttn" in down_block_type + is_final_block = i == len(down_block_types) - 1 + + down_blocks.append( + ControlNetXSCrossAttnDownBlock2D( + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block[i], + base_num_attention_heads=base_num_attention_heads[i], + ctrl_num_attention_heads=ctrl_num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + add_downsample=not is_final_block, + upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, + ) + ) + + # # Create mid block + self.mid_block = ControlNetXSCrossAttnMidBlock2D( + base_channels=block_out_channels[-1], + ctrl_channels=ctrl_block_out_channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, + transformer_layers_per_block=transformer_layers_per_block[-1], + base_num_attention_heads=base_num_attention_heads[-1], + ctrl_num_attention_heads=ctrl_num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, + ) + + # # Create up blocks + up_blocks = [] + rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + rev_num_attention_heads = list(reversed(base_num_attention_heads)) + rev_cross_attention_dim = list(reversed(cross_attention_dim)) + + # The skip connection channels are the output of the conv_in and of all the down subblocks + ctrl_skip_channels = [ctrl_block_out_channels[0]] + for i, out_channels in enumerate(ctrl_block_out_channels): + number_of_subblocks = ( + 3 if i < len(ctrl_block_out_channels) - 1 else 2 + ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler + ctrl_skip_channels.extend([out_channels] * number_of_subblocks) + + reversed_block_out_channels = list(reversed(block_out_channels)) + + out_channels = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = out_channels + out_channels = reversed_block_out_channels[i] + in_channels = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] + + has_crossattn = "CrossAttn" in up_block_type + is_final_block = i == len(block_out_channels) - 1 + + up_blocks.append( + ControlNetXSCrossAttnUpBlock2D( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + ctrl_skip_channels=ctrl_skip_channels_, + temb_channels=time_embed_dim, + resolution_idx=i, + has_crossattn=has_crossattn, + transformer_layers_per_block=rev_transformer_layers_per_block[i], + num_attention_heads=rev_num_attention_heads[i], + cross_attention_dim=rev_cross_attention_dim[i], + add_upsample=not is_final_block, + upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + ) + ) + + self.down_blocks = nn.ModuleList(down_blocks) + self.up_blocks = nn.ModuleList(up_blocks) + + self.base_conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups) + self.base_conv_act = nn.SiLU() + self.base_conv_out = nn.Conv2d(block_out_channels[0], 4, kernel_size=3, padding=1) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet: Optional[ControlNetXSAdapter] = None, + size_ratio: Optional[float] = None, + ctrl_block_out_channels: Optional[List[float]] = None, + time_embedding_mix: Optional[float] = None, + ctrl_optional_kwargs: Optional[Dict] = None, + ): + r""" + Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional [`ControlNetXSAdapter`] + . + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model we want to control. + controlnet (`ControlNetXSAdapter`): + The ConntrolNet-XS adapter with which the UNet will be fused. If none is given, a new ConntrolNet-XS + adapter will be created. + size_ratio (float, *optional*, defaults to `None`): + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details. + ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`): + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details, + where this parameter is called `block_out_channels`. + time_embedding_mix (`float`, *optional*, defaults to None): + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details. + ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`): + Passed to the `init` of the new controlent if no controlent was given. + """ + if controlnet is None: + controlnet = ControlNetXSAdapter.from_unet( + unet, size_ratio, ctrl_block_out_channels, **ctrl_optional_kwargs + ) + else: + if any( + o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs) + ): + raise ValueError( + "When a controlnet is passed, none of these parameters should be passed: size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs." + ) + + # # get params + params_for_unet = [ + "sample_size", + "down_block_types", + "up_block_types", + "block_out_channels", + "norm_num_groups", + "cross_attention_dim", + "transformer_layers_per_block", + "addition_embed_type", + "addition_time_embed_dim", + "upcast_attention", + "use_linear_projection", + "time_cond_proj_dim", + "projection_class_embeddings_input_dim", + ] + params_for_unet = {k: v for k, v in unet.config.items() if k in params_for_unet} + # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + params_for_unet["num_attention_heads"] = unet.config.attention_head_dim + + params_for_controlnet = [ + "conditioning_channels", + "conditioning_embedding_out_channels", + "conditioning_channel_order", + "learn_time_embedding", + "block_out_channels", + "num_attention_heads", + "max_norm_num_groups", + ] + params_for_controlnet = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet} + params_for_controlnet["time_embedding_mix"] = controlnet.config.time_embedding_mix + + # # create model + model = cls.from_config({**params_for_unet, **params_for_controlnet}) + + # # load weights + # from unet + modules_from_unet = [ + "time_embedding", + "conv_in", + "conv_norm_out", + "conv_out", + ] + for m in modules_from_unet: + getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict()) + + optional_modules_from_unet = [ + "add_time_proj", + "add_embedding", + ] + for m in optional_modules_from_unet: + if hasattr(unet, m) and getattr(unet, m) is not None: + getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict()) + + # from controlnet + model.controlnet_cond_embedding.load_state_dict(controlnet.controlnet_cond_embedding.state_dict()) + model.ctrl_conv_in.load_state_dict(controlnet.conv_in.state_dict()) + if controlnet.time_embedding is not None: + model.ctrl_time_embedding.load_state_dict(controlnet.time_embedding.state_dict()) + model.control_to_base_for_conv_in.load_state_dict(controlnet.control_to_base_for_conv_in.state_dict()) + + # from both + model.down_blocks = nn.ModuleList( + ControlNetXSCrossAttnDownBlock2D.from_modules(b, c) + for b, c in zip(unet.down_blocks, controlnet.down_blocks) + ) + model.mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet.mid_block) + model.up_blocks = nn.ModuleList( + ControlNetXSCrossAttnUpBlock2D.from_modules(b, c) + for b, c in zip(unet.up_blocks, controlnet.up_connections) + ) + + # ensure that the UNetControlNetXSModel is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + + return model + + def freeze_unet_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" + # Freeze everything + for param in self.parameters(): + param.requires_grad = True + + # Unfreeze ControlNetXSAdapter + base_parts = [ + "base_time_proj", + "base_time_embedding", + "base_add_time_proj", + "base_add_embedding", + "base_conv_in", + "base_conv_norm_out", + "base_conv_act", + "base_conv_out", + ] + base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None] + for part in base_parts: + for param in part.parameters(): + param.requires_grad = False + + for d in self.down_blocks: + d.freeze_base_params() + self.mid_block.freeze_base_params() + for u in self.up_blocks: + u.freeze_base_params() + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + sample: Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: Optional[torch.Tensor] = None, + conditioning_scale: Optional[float] = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + return_dict: bool = True, + apply_control: bool = True, + ) -> Union[ControlNetXSOutput, Tuple]: + """ + The [`ControlNetXSModel`] forward method. + + Args: + sample (`Tensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + How much the control model affects the base model outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + apply_control (`bool`, defaults to `True`): + If `False`, the input is run only through the base model. + + Returns: + [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + + # check channel order + if self.config.ctrl_conditioning_channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.base_time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + if self.config.ctrl_learn_time_embedding and apply_control: + ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond) + base_temb = self.base_time_embedding(t_emb, timestep_cond) + interpolation_param = self.config.time_embedding_mix**0.3 + + temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) + else: + temb = self.base_time_embedding(t_emb) + + # added time & text embeddings + aug_emb = None + + if self.config.addition_embed_type is None: + pass + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.base_add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(temb.dtype) + aug_emb = self.base_add_embedding(add_embeds) + else: + raise ValueError( + f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config.addition_embed_type} is currently not supported." + ) + + temb = temb + aug_emb if aug_emb is not None else temb + + # text embeddings + cemb = encoder_hidden_states + + # Preparation + h_ctrl = h_base = sample + hs_base, hs_ctrl = [], [] + + # Cross Control + guided_hint = self.controlnet_cond_embedding(controlnet_cond) + + # 1 - conv in & down + + h_base = self.base_conv_in(h_base) + h_ctrl = self.ctrl_conv_in(h_ctrl) + if guided_hint is not None: + h_ctrl += guided_hint + if apply_control: + h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale # add ctrl -> base + + hs_base.append(h_base) + hs_ctrl.append(h_ctrl) + + for down in self.down_blocks: + h_base, h_ctrl, residual_hb, residual_hc = down( + hidden_states_base=h_base, + hidden_states_ctrl=h_ctrl, + temb=temb, + encoder_hidden_states=cemb, + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + apply_control=apply_control, + ) + hs_base.extend(residual_hb) + hs_ctrl.extend(residual_hc) + + # 2 - mid + h_base, h_ctrl = self.mid_block( + hidden_states_base=h_base, + hidden_states_ctrl=h_ctrl, + temb=temb, + encoder_hidden_states=cemb, + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + apply_control=apply_control, + ) + + # 3 - up + for up in self.up_blocks: + n_resnets = len(up.resnets) + skips_hb = hs_base[-n_resnets:] + skips_hc = hs_ctrl[-n_resnets:] + hs_base = hs_base[:-n_resnets] + hs_ctrl = hs_ctrl[:-n_resnets] + h_base = up( + hidden_states=h_base, + res_hidden_states_tuple_base=skips_hb, + res_hidden_states_tuple_ctrl=skips_hc, + temb=temb, + encoder_hidden_states=cemb, + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + apply_control=apply_control, + ) + + # 4 - conv out + h_base = self.base_conv_norm_out(h_base) + h_base = self.base_conv_act(h_base) + h_base = self.base_conv_out(h_base) + + if not return_dict: + return (h_base,) + + return ControlNetXSOutput(sample=h_base) + + +class ControlNetXSCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + base_in_channels: int, + base_out_channels: int, + ctrl_in_channels: int, + ctrl_out_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + ctrl_max_norm_num_groups: int = 32, + has_crossattn=True, + transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, + base_num_attention_heads: Optional[int] = 1, + ctrl_num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + add_downsample: bool = True, + upcast_attention: Optional[bool] = False, + use_linear_projection: Optional[bool] = True, + ): + super().__init__() + base_resnets = [] + base_attentions = [] + ctrl_resnets = [] + ctrl_attentions = [] + ctrl_to_base = [] + base_to_ctrl = [] + + num_layers = 2 # only support sd + sdxl + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + base_in_channels = base_in_channels if i == 0 else base_out_channels + ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels + + # Before the resnet/attention application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) + + base_resnets.append( + ResnetBlock2D( + in_channels=base_in_channels, + out_channels=base_out_channels, + temb_channels=temb_channels, + groups=norm_num_groups, + ) + ) + ctrl_resnets.append( + ResnetBlock2D( + in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl + out_channels=ctrl_out_channels, + temb_channels=temb_channels, + groups=find_largest_factor( + ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups + ), + groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), + eps=1e-5, + ) + ) + + if has_crossattn: + base_attentions.append( + Transformer2DModel( + base_num_attention_heads, + base_out_channels // base_num_attention_heads, + in_channels=base_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, + ) + ) + ctrl_attentions.append( + Transformer2DModel( + ctrl_num_attention_heads, + ctrl_out_channels // ctrl_num_attention_heads, + in_channels=ctrl_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), + ) + ) + + # After the resnet/attention application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + + if add_downsample: + # Before the downsampler application, information is concatted from base to control + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) + + self.base_downsamplers = Downsample2D( + base_out_channels, use_conv=True, out_channels=base_out_channels, name="op" + ) + self.ctrl_downsamplers = Downsample2D( + ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" + ) + + # After the downsampler application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + else: + self.base_downsamplers = None + self.ctrl_downsamplers = None + + self.base_resnets = nn.ModuleList(base_resnets) + self.ctrl_resnets = nn.ModuleList(ctrl_resnets) + self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None] * num_layers + self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None] * num_layers + self.base_to_ctrl = nn.ModuleList(base_to_ctrl) + self.ctrl_to_base = nn.ModuleList(ctrl_to_base) + + self.gradient_checkpointing = False + + @classmethod + def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: DownBlockControlNetXSAdapter): + # get params + def get_first_cross_attention(block): + return block.attentions[0].transformer_blocks[0].attn2 + + base_in_channels = base_downblock.resnets[0].in_channels + base_out_channels = base_downblock.resnets[0].out_channels + ctrl_in_channels = ( + ctrl_downblock.resnets[0].in_channels - base_in_channels + ) # base channels are concatted to ctrl channels in init + ctrl_out_channels = ctrl_downblock.resnets[0].out_channels + temb_channels = base_downblock.resnets[0].time_emb_proj.in_features + num_groups = base_downblock.resnets[0].norm1.num_groups + ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups + if hasattr(base_downblock, "attentions"): + has_crossattn = True + transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks) + base_num_attention_heads = get_first_cross_attention(base_downblock).heads + ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads + cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_downblock).upcast_attention + use_linear_projection = base_downblock.attentions[0].use_linear_projection + else: + has_crossattn = False + transformer_layers_per_block = None + base_num_attention_heads = None + ctrl_num_attention_heads = None + cross_attention_dim = None + upcast_attention = None + use_linear_projection = None + add_downsample = base_downblock.downsamplers is not None + + # create model + model = cls( + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=temb_channels, + norm_num_groups=num_groups, + ctrl_max_norm_num_groups=ctrl_num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block, + base_num_attention_heads=base_num_attention_heads, + ctrl_num_attention_heads=ctrl_num_attention_heads, + cross_attention_dim=cross_attention_dim, + add_downsample=add_downsample, + upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, + ) + + # # load weights + model.base_resnets.load_state_dict(base_downblock.resnets.state_dict()) + model.ctrl_resnets.load_state_dict(ctrl_downblock.resnets.state_dict()) + if has_crossattn: + model.base_attentions.load_state_dict(base_downblock.attentions.state_dict()) + model.ctrl_attentions.load_state_dict(ctrl_downblock.attentions.state_dict()) + if add_downsample: + model.base_downsamplers.load_state_dict(base_downblock.downsamplers[0].state_dict()) + model.ctrl_downsamplers.load_state_dict(ctrl_downblock.downsamplers.state_dict()) + model.base_to_ctrl.load_state_dict(ctrl_downblock.base_to_ctrl.state_dict()) + model.ctrl_to_base.load_state_dict(ctrl_downblock.ctrl_to_base.state_dict()) + + return model + + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" + # Unfreeze everything + for param in self.parameters(): + param.requires_grad = True + + # Freeze base part + base_parts = [self.base_resnets] + if isinstance(self.base_attentions, nn.ModuleList): # attentions can be a list of Nones + base_parts.append(self.base_attentions) + if self.base_downsamplers is not None: + base_parts.append(self.base_downsamplers) + for part in base_parts: + for param in part.parameters(): + param.requires_grad = False + + def forward( + self, + hidden_states_base: Tensor, + temb: Tensor, + encoder_hidden_states: Optional[Tensor] = None, + hidden_states_ctrl: Optional[Tensor] = None, + conditioning_scale: Optional[float] = 1.0, + attention_mask: Optional[Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[Tensor] = None, + apply_control: bool = True, + ) -> Tuple[Tensor, Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...]]: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + h_base = hidden_states_base + h_ctrl = hidden_states_ctrl + + base_output_states = () + ctrl_output_states = () + + base_blocks = list(zip(self.base_resnets, self.base_attentions)) + ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions)) + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip( + base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base + ): + # concat base -> ctrl + if apply_control: + h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) + + # apply base subblock + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + h_base = torch.utils.checkpoint.checkpoint( + create_custom_forward(b_res), + h_base, + temb, + **ckpt_kwargs, + ) + else: + h_base = b_res(h_base, temb) + + if b_attn is not None: + h_base = b_attn( + h_base, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply ctrl subblock + if apply_control: + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + h_ctrl = torch.utils.checkpoint.checkpoint( + create_custom_forward(c_res), + h_ctrl, + temb, + **ckpt_kwargs, + ) + else: + h_ctrl = c_res(h_ctrl, temb) + if c_attn is not None: + h_ctrl = c_attn( + h_ctrl, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # add ctrl -> base + if apply_control: + h_base = h_base + c2b(h_ctrl) * conditioning_scale + + base_output_states = base_output_states + (h_base,) + ctrl_output_states = ctrl_output_states + (h_ctrl,) + + if self.base_downsamplers is not None: # if we have a base_downsampler, then also a ctrl_downsampler + b2c = self.base_to_ctrl[-1] + c2b = self.ctrl_to_base[-1] + + # concat base -> ctrl + if apply_control: + h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) + # apply base subblock + h_base = self.base_downsamplers(h_base) + # apply ctrl subblock + if apply_control: + h_ctrl = self.ctrl_downsamplers(h_ctrl) + # add ctrl -> base + if apply_control: + h_base = h_base + c2b(h_ctrl) * conditioning_scale + + base_output_states = base_output_states + (h_base,) + ctrl_output_states = ctrl_output_states + (h_ctrl,) + + return h_base, h_ctrl, base_output_states, ctrl_output_states + + +class ControlNetXSCrossAttnMidBlock2D(nn.Module): + def __init__( + self, + base_channels: int, + ctrl_channels: int, + temb_channels: Optional[int] = None, + norm_num_groups: int = 32, + ctrl_max_norm_num_groups: int = 32, + transformer_layers_per_block: int = 1, + base_num_attention_heads: Optional[int] = 1, + ctrl_num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + upcast_attention: bool = False, + use_linear_projection: Optional[bool] = True, + ): + super().__init__() + + # Before the midblock application, information is concatted from base to control. + # Concat doesn't require change in number of channels + self.base_to_ctrl = make_zero_conv(base_channels, base_channels) + + self.base_midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=base_channels, + temb_channels=temb_channels, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=base_num_attention_heads, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + self.ctrl_midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=ctrl_channels + base_channels, + out_channels=ctrl_channels, + temb_channels=temb_channels, + # number or norm groups must divide both in_channels and out_channels + resnet_groups=find_largest_factor( + gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups + ), + cross_attention_dim=cross_attention_dim, + num_attention_heads=ctrl_num_attention_heads, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + # After the midblock application, information is added from control to base + # Addition requires change in number of channels + self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) + + self.gradient_checkpointing = False + + @classmethod + def from_modules( + cls, + base_midblock: UNetMidBlock2DCrossAttn, + ctrl_midblock: MidBlockControlNetXSAdapter, + ): + base_to_ctrl = ctrl_midblock.base_to_ctrl + ctrl_to_base = ctrl_midblock.ctrl_to_base + ctrl_midblock = ctrl_midblock.midblock + + # get params + def get_first_cross_attention(midblock): + return midblock.attentions[0].transformer_blocks[0].attn2 + + base_channels = ctrl_to_base.out_channels + ctrl_channels = ctrl_to_base.in_channels + transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks) + temb_channels = base_midblock.resnets[0].time_emb_proj.in_features + num_groups = base_midblock.resnets[0].norm1.num_groups + ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups + base_num_attention_heads = get_first_cross_attention(base_midblock).heads + ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads + cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_midblock).upcast_attention + use_linear_projection = base_midblock.attentions[0].use_linear_projection + + # create model + model = cls( + base_channels=base_channels, + ctrl_channels=ctrl_channels, + temb_channels=temb_channels, + norm_num_groups=num_groups, + ctrl_max_norm_num_groups=ctrl_num_groups, + transformer_layers_per_block=transformer_layers_per_block, + base_num_attention_heads=base_num_attention_heads, + ctrl_num_attention_heads=ctrl_num_attention_heads, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, + ) + + # load weights + model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict()) + model.base_midblock.load_state_dict(base_midblock.state_dict()) + model.ctrl_midblock.load_state_dict(ctrl_midblock.state_dict()) + model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict()) + + return model + + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" + # Unfreeze everything + for param in self.parameters(): + param.requires_grad = True + + # Freeze base part + for param in self.base_midblock.parameters(): + param.requires_grad = False + + def forward( + self, + hidden_states_base: Tensor, + temb: Tensor, + encoder_hidden_states: Tensor, + hidden_states_ctrl: Optional[Tensor] = None, + conditioning_scale: Optional[float] = 1.0, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_mask: Optional[Tensor] = None, + encoder_attention_mask: Optional[Tensor] = None, + apply_control: bool = True, + ) -> Tuple[Tensor, Tensor]: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + h_base = hidden_states_base + h_ctrl = hidden_states_ctrl + + joint_args = { + "temb": temb, + "encoder_hidden_states": encoder_hidden_states, + "attention_mask": attention_mask, + "cross_attention_kwargs": cross_attention_kwargs, + "encoder_attention_mask": encoder_attention_mask, + } + + if apply_control: + h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1) # concat base -> ctrl + h_base = self.base_midblock(h_base, **joint_args) # apply base mid block + if apply_control: + h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args) # apply ctrl mid block + h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale # add ctrl -> base + + return h_base, h_ctrl + + +class ControlNetXSCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + ctrl_skip_channels: List[int], + temb_channels: int, + norm_num_groups: int = 32, + resolution_idx: Optional[int] = None, + has_crossattn=True, + transformer_layers_per_block: int = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1024, + add_upsample: bool = True, + upcast_attention: bool = False, + use_linear_projection: Optional[bool] = True, + ): + super().__init__() + resnets = [] + attentions = [] + ctrl_to_base = [] + + num_layers = 3 # only support sd + sdxl + + self.has_cross_attention = has_crossattn + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=norm_num_groups, + ) + ) + + if has_crossattn: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers + self.ctrl_to_base = nn.ModuleList(ctrl_to_base) + + if add_upsample: + self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + @classmethod + def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter): + ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base + + # get params + def get_first_cross_attention(block): + return block.attentions[0].transformer_blocks[0].attn2 + + out_channels = base_upblock.resnets[0].out_channels + in_channels = base_upblock.resnets[-1].in_channels - out_channels + prev_output_channels = base_upblock.resnets[0].in_channels - out_channels + ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections] + temb_channels = base_upblock.resnets[0].time_emb_proj.in_features + num_groups = base_upblock.resnets[0].norm1.num_groups + resolution_idx = base_upblock.resolution_idx + if hasattr(base_upblock, "attentions"): + has_crossattn = True + transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks) + num_attention_heads = get_first_cross_attention(base_upblock).heads + cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_upblock).upcast_attention + use_linear_projection = base_upblock.attentions[0].use_linear_projection + else: + has_crossattn = False + transformer_layers_per_block = None + num_attention_heads = None + cross_attention_dim = None + upcast_attention = None + use_linear_projection = None + add_upsample = base_upblock.upsamplers is not None + + # create model + model = cls( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channels, + ctrl_skip_channels=ctrl_skip_channelss, + temb_channels=temb_channels, + norm_num_groups=num_groups, + resolution_idx=resolution_idx, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block, + num_attention_heads=num_attention_heads, + cross_attention_dim=cross_attention_dim, + add_upsample=add_upsample, + upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, + ) + + # load weights + model.resnets.load_state_dict(base_upblock.resnets.state_dict()) + if has_crossattn: + model.attentions.load_state_dict(base_upblock.attentions.state_dict()) + if add_upsample: + model.upsamplers.load_state_dict(base_upblock.upsamplers[0].state_dict()) + model.ctrl_to_base.load_state_dict(ctrl_to_base_skip_connections.state_dict()) + + return model + + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" + # Unfreeze everything + for param in self.parameters(): + param.requires_grad = True + + # Freeze base part + base_parts = [self.resnets] + if isinstance(self.attentions, nn.ModuleList): # attentions can be a list of Nones + base_parts.append(self.attentions) + if self.upsamplers is not None: + base_parts.append(self.upsamplers) + for part in base_parts: + for param in part.parameters(): + param.requires_grad = False + + def forward( + self, + hidden_states: Tensor, + res_hidden_states_tuple_base: Tuple[Tensor, ...], + res_hidden_states_tuple_ctrl: Tuple[Tensor, ...], + temb: Tensor, + encoder_hidden_states: Optional[Tensor] = None, + conditioning_scale: Optional[float] = 1.0, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_mask: Optional[Tensor] = None, + upsample_size: Optional[int] = None, + encoder_attention_mask: Optional[Tensor] = None, + apply_control: bool = True, + ) -> Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + return apply_freeu( + self.resolution_idx, + hidden_states, + res_h_base, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + else: + return hidden_states, res_h_base + + for resnet, attn, c2b, res_h_base, res_h_ctrl in zip( + self.resnets, + self.attentions, + self.ctrl_to_base, + reversed(res_hidden_states_tuple_base), + reversed(res_hidden_states_tuple_ctrl), + ): + if apply_control: + hidden_states += c2b(res_h_ctrl) * conditioning_scale + + hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base) + hidden_states = torch.cat([hidden_states, res_h_base], dim=1) + + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + + if attn is not None: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + hidden_states = self.upsamplers(hidden_states, upsample_size) + + return hidden_states + + +def make_zero_conv(in_channels, out_channels=None): + return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +def find_largest_factor(number, max_factor): + factor = max_factor + if factor >= number: + return number + while factor != 0: + residual = number % factor + if residual == 0: + return factor + factor -= 1 diff --git a/diffusers/src/diffusers/models/downsampling.py b/diffusers/src/diffusers/models/downsampling.py new file mode 100755 index 0000000..3ac8953 --- /dev/null +++ b/diffusers/src/diffusers/models/downsampling.py @@ -0,0 +1,401 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import deprecate +from .normalization import RMSNorm +from .upsampling import upfirdn2d_native + + +class Downsample1D(nn.Module): + """A 1D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 1D layer. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + assert inputs.shape[1] == self.channels + return self.conv(inputs) + + +class Downsample2D(nn.Module): + """A 2D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 2D layer. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + kernel_size=3, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + if use_conv: + conv = nn.Conv2d( + self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class FirDownsample2D(nn.Module): + """A 2D FIR downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + + def __init__( + self, + channels: Optional[int] = None, + out_channels: Optional[int] = None, + use_conv: bool = False, + fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), + ): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d( + self, + hidden_states: torch.Tensor, + weight: Optional[torch.Tensor] = None, + kernel: Optional[torch.Tensor] = None, + factor: int = 2, + gain: float = 1, + ) -> torch.Tensor: + """Fused `Conv2d()` followed by `downsample_2d()`. + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states (`torch.Tensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight (`torch.Tensor`, *optional*): + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel (`torch.Tensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to average pooling. + factor (`int`, *optional*, default to `2`): + Integer downsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude. + + Returns: + output (`torch.Tensor`): + Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same + datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * gain + + if self.use_conv: + _, _, convH, convW = weight.shape + pad_value = (kernel.shape[0] - factor) + (convW - 1) + stride_value = [factor, factor] + upfirdn_input = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + pad=((pad_value + 1) // 2, pad_value // 2), + ) + output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + + return output + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.use_conv: + downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) + hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) + + return hidden_states + + +# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead +class KDownsample2D(nn.Module): + r"""A 2D K-downsampling layer. + + Parameters: + pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. + """ + + def __init__(self, pad_mode: str = "reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode) + weight = inputs.new_zeros( + [ + inputs.shape[1], + inputs.shape[1], + self.kernel.shape[0], + self.kernel.shape[1], + ] + ) + indices = torch.arange(inputs.shape[1], device=inputs.device) + kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) + weight[indices, indices] = kernel + return F.conv2d(inputs, weight, stride=2) + + +class CogVideoXDownsample3D(nn.Module): + # Todo: Wait for paper relase. + r""" + A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI + + Args: + in_channels (`int`): + Number of channels in the input image. + out_channels (`int`): + Number of channels produced by the convolution. + kernel_size (`int`, defaults to `3`): + Size of the convolving kernel. + stride (`int`, defaults to `2`): + Stride of the convolution. + padding (`int`, defaults to `0`): + Padding added to all four sides of the input. + compress_time (`bool`, defaults to `False`): + Whether or not to compress the time dimension. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 2, + padding: int = 0, + compress_time: bool = False, + ): + super().__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.compress_time: + batch_size, channels, frames, height, width = x.shape + + # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames) + x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames) + + if x.shape[-1] % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if x_rest.shape[-1] > 0: + # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2) + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + + x = torch.cat([x_first[..., None], x_rest], dim=-1) + # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width) + x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) + else: + # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2) + x = F.avg_pool1d(x, kernel_size=2, stride=2) + # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width) + x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) + + # Pad the tensor + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + batch_size, channels, frames, height, width = x.shape + # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width) + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width) + x = self.conv(x) + # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width) + x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x + + +def downsample_2d( + hidden_states: torch.Tensor, + kernel: Optional[torch.Tensor] = None, + factor: int = 2, + gain: float = 1, +) -> torch.Tensor: + r"""Downsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + + Args: + hidden_states (`torch.Tensor`) + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel (`torch.Tensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to average pooling. + factor (`int`, *optional*, default to `2`): + Integer downsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude. + + Returns: + output (`torch.Tensor`): + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * gain + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel.to(device=hidden_states.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + return output diff --git a/diffusers/src/diffusers/models/embeddings.py b/diffusers/src/diffusers/models/embeddings.py new file mode 100755 index 0000000..c4a80bc --- /dev/null +++ b/diffusers/src/diffusers/models/embeddings.py @@ -0,0 +1,1742 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ..utils import deprecate +from .activations import FP32SiLU, get_activation +from .attention_processor import Attention + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_3d_sincos_pos_embed( + embed_dim: int, + spatial_size: Union[int, Tuple[int, int]], + temporal_size: int, + spatial_interpolation_scale: float = 1.0, + temporal_interpolation_scale: float = 1.0, +) -> np.ndarray: + r""" + Args: + embed_dim (`int`): + spatial_size (`int` or `Tuple[int, int]`): + temporal_size (`int`): + spatial_interpolation_scale (`float`, defaults to 1.0): + temporal_interpolation_scale (`float`, defaults to 1.0): + """ + if embed_dim % 4 != 0: + raise ValueError("`embed_dim` must be divisible by 4") + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + embed_dim_spatial = 3 * embed_dim // 4 + embed_dim_temporal = embed_dim // 4 + + # 1. Spatial + grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale + grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) + + # 2. Temporal + grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) + + # 3. Concat + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3] + + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4] + + pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D] + return pos_embed + + +def get_2d_sincos_pos_embed( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding with support for SD3 cropping.""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + pos_embed_type="sincos", + pos_embed_max_size=None, # For SD3 cropping + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + self.interpolation_scale = interpolation_scale + + # Calculate positional embeddings based on max size or default + if pos_embed_max_size: + grid_size = pos_embed_max_size + else: + grid_size = int(num_patches**0.5) + + if pos_embed_type is None: + self.pos_embed = None + elif pos_embed_type == "sincos": + pos_embed = get_2d_sincos_pos_embed( + embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + persistent = True if pos_embed_max_size else False + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) + else: + raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") + + def cropped_pos_embed(self, height, width): + """Crops positional embeddings for SD3 compatibility.""" + if self.pos_embed_max_size is None: + raise ValueError("`pos_embed_max_size` must be set for cropping.") + + height = height // self.patch_size + width = width // self.patch_size + if height > self.pos_embed_max_size: + raise ValueError( + f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + if width > self.pos_embed_max_size: + raise ValueError( + f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + + top = (self.pos_embed_max_size - height) // 2 + left = (self.pos_embed_max_size - width) // 2 + spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) + spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def forward(self, latent): + if self.pos_embed_max_size is not None: + height, width = latent.shape[-2:] + else: + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + if self.pos_embed is None: + return latent.to(latent.dtype) + # Interpolate or crop positional embeddings as needed + if self.pos_embed_max_size: + pos_embed = self.cropped_pos_embed(height, width) + else: + if self.height != height or self.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + return (latent + pos_embed).to(latent.dtype) + + +class LuminaPatchEmbed(nn.Module): + """2D Image to Patch Embedding with support for Lumina-T2X""" + + def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=embed_dim, + bias=bias, + ) + + def forward(self, x, freqs_cis): + """ + Patchifies and embeds the input tensor(s). + + Args: + x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded. + + Returns: + Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified + and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the + frequency tensor(s). + """ + freqs_cis = freqs_cis.to(x[0].device) + patch_height = patch_width = self.patch_size + batch_size, channel, height, width = x.size() + height_tokens, width_tokens = height // patch_height, width // patch_width + + x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute( + 0, 2, 4, 1, 3, 5 + ) + x = x.flatten(3) + x = self.proj(x) + x = x.flatten(1, 2) + + mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device) + + return ( + x, + mask, + [(height, width)] * batch_size, + freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0), + ) + + +class CogVideoXPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + embed_dim: int = 1920, + text_embed_dim: int = 4096, + bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_positional_embeddings: bool = True, + use_learned_positional_embeddings: bool = True, + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.embed_dim = embed_dim + self.sample_height = sample_height + self.sample_width = sample_width + self.sample_frames = sample_frames + self.temporal_compression_ratio = temporal_compression_ratio + self.max_text_seq_length = max_text_seq_length + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.use_positional_embeddings = use_positional_embeddings + self.use_learned_positional_embeddings = use_learned_positional_embeddings + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + self.text_proj = nn.Linear(text_embed_dim, embed_dim) + + if use_positional_embeddings or use_learned_positional_embeddings: + persistent = use_learned_positional_embeddings + pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) + self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) + + def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor: + post_patch_height = sample_height // self.patch_size + post_patch_width = sample_width // self.patch_size + post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 + num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + pos_embedding = get_3d_sincos_pos_embed( + self.embed_dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + ) + pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1) + joint_pos_embedding = torch.zeros( + 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False + ) + joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding) + + return joint_pos_embedding + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor, masks: torch.Tensor = None): + r""" + Args: + text_embeds (`torch.Tensor`): + Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). + image_embeds (`torch.Tensor`): + Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). + """ + text_embeds = self.text_proj(text_embeds) + + batch, num_frames, channels, height, width = image_embeds.shape + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] + image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] + + if masks is not None: + # Process mask + masks = masks.reshape(-1, 1, height, width) + masks = F.avg_pool2d(masks, kernel_size=self.patch_size, stride=self.patch_size) + masks = masks.view(batch, num_frames, *masks.shape[1:]) + masks = masks.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] + masks = masks.flatten(1, 2) # [batch, num_frames x height x width, channels] + # BUG: patchsize = (2, 2)时,patch中大于等于3个pixel为1,则mask为1, 会导致mask缩小,从而分支多复制额外的mask区域,导致伪影 + # masks = (masks > 0.5).bool() + masks = (masks > 0.0).bool() + + embeds = torch.cat( + [text_embeds, image_embeds], dim=1 + ).contiguous() # [batch, seq_length + num_frames x height x width, channels] + + if self.use_positional_embeddings or self.use_learned_positional_embeddings: + if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height): + raise ValueError( + "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'." + "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues." + ) + + pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + if ( + self.sample_height != height + or self.sample_width != width + or self.sample_frames != pre_time_compression_frames + ): + pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames) + pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype) + else: + pos_embedding = self.pos_embedding + + embeds = embeds + pos_embedding + if masks is not None: + return embeds, masks + return embeds + + +def get_3d_rotary_pos_embed( + embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. + """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin + + +def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): + """ + RoPE for image tokens with 2d structure. + + Args: + embed_dim: (`int`): + The embedding dimension size + crops_coords (`Tuple[int]`) + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the positional embedding. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. + """ + start, stop = crops_coords + grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): + assert embed_dim % 4 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, grid[0].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, grid[1].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + + if use_real: + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) + return cos, sin + else: + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) + return emb + + +def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0): + assert embed_dim % 4 == 0 + + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor + ) # (H, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor + ) # (W, D/4) + emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1) + emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1) + + emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2) + return emb + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + # stable audio + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class FluxPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + del self.weight + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.weight = self.W + del self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +class SinusoidalPositionalEmbedding(nn.Module): + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. + max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x + + +class ImagePositionalEmbeddings(nn.Module): + """ + Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the + height and width of the latent space. + + For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 + + For VQ-diffusion: + + Output vector embeddings are used as input for the transformer. + + Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. + + Args: + num_embed (`int`): + Number of embeddings for the latent pixels embeddings. + height (`int`): + Height of the latent image i.e. the number of height embeddings. + width (`int`): + Width of the latent image i.e. the number of width embeddings. + embed_dim (`int`): + Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. + """ + + def __init__( + self, + num_embed: int, + height: int, + width: int, + embed_dim: int, + ): + super().__init__() + + self.height = height + self.width = width + self.num_embed = num_embed + self.embed_dim = embed_dim + + self.emb = nn.Embedding(self.num_embed, embed_dim) + self.height_emb = nn.Embedding(self.height, embed_dim) + self.width_emb = nn.Embedding(self.width, embed_dim) + + def forward(self, index): + emb = self.emb(index) + + height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) + + # 1 x H x D -> 1 x H x 1 x D + height_emb = height_emb.unsqueeze(2) + + width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) + + # 1 x W x D -> 1 x 1 x W x D + width_emb = width_emb.unsqueeze(1) + + pos_emb = height_emb + width_emb + + # 1 x H x W x D -> 1 x L xD + pos_emb = pos_emb.view(1, self.height * self.width, -1) + + emb = emb + pos_emb[:, : emb.shape[1], :] + + return emb + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels: torch.LongTensor, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (self.training and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class TextImageProjection(nn.Module): + def __init__( + self, + text_embed_dim: int = 1024, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 10, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + batch_size = text_embeds.shape[0] + + # image + image_text_embeds = self.image_embeds(image_embeds) + image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + + # text + text_embeds = self.text_proj(text_embeds) + + return torch.cat([image_text_embeds, text_embeds], dim=1) + + +class ImageProjection(nn.Module): + def __init__( + self, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 32, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + batch_size = image_embeds.shape[0] + + # image + image_embeds = self.image_embeds(image_embeds) + image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + image_embeds = self.norm(image_embeds) + return image_embeds + + +class IPAdapterFullImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): + super().__init__() + from .attention import FeedForward + + self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + return self.norm(self.ff(image_embeds)) + + +class IPAdapterFaceIDImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): + super().__init__() + from .attention import FeedForward + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) + + +class CombinedTimestepLabelEmbeddings(nn.Module): + def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) + + def forward(self, timestep, class_labels, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + class_labels = self.class_embedder(class_labels) # (N, D) + + conditioning = timesteps_emb + class_labels # (N, D) + + return conditioning + + +class CombinedTimestepTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + pooled_projections = self.text_embedder(pooled_projection) + + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, guidance, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning + + +class HunyuanDiTAttentionPool(nn.Module): + # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 + + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + return x.squeeze(0) + + +class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): + def __init__( + self, + embedding_dim, + pooled_projection_dim=1024, + seq_len=256, + cross_attention_dim=2048, + use_style_cond_and_image_meta_size=True, + ): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + + self.pooler = HunyuanDiTAttentionPool( + seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim + ) + + # Here we use a default learned embedder layer for future extension. + self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size + if use_style_cond_and_image_meta_size: + self.style_embedder = nn.Embedding(1, embedding_dim) + extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim + else: + extra_in_dim = pooled_projection_dim + + self.extra_embedder = PixArtAlphaTextProjection( + in_features=extra_in_dim, + hidden_size=embedding_dim * 4, + out_features=embedding_dim, + act_fn="silu_fp32", + ) + + def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) + + # extra condition1: text + pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) + + if self.use_style_cond_and_image_meta_size: + # extra condition2: image meta size embedding + image_meta_size = self.size_proj(image_meta_size.view(-1)) + image_meta_size = image_meta_size.to(dtype=hidden_dtype) + image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) + + # extra condition3: style embedding + style_embedding = self.style_embedder(style) # (N, embedding_dim) + + # Concatenate all extra vectors + extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) + else: + extra_cond = torch.cat([pooled_projections], dim=1) + + conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] + + return conditioning + + +class LuminaCombinedTimestepCaptionEmbedding(nn.Module): + def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256): + super().__init__() + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 + ) + + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size) + + self.caption_embedder = nn.Sequential( + nn.LayerNorm(cross_attention_dim), + nn.Linear( + cross_attention_dim, + hidden_size, + bias=True, + ), + ) + + def forward(self, timestep, caption_feat, caption_mask): + # timestep embedding: + time_freq = self.time_proj(timestep) + time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) + + # caption condition embedding: + caption_mask_float = caption_mask.float().unsqueeze(-1) + caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1) + caption_feats_pool = caption_feats_pool.to(caption_feat) + caption_embed = self.caption_embedder(caption_feats_pool) + + conditioning = time_embed + caption_embed + + return conditioning + + +class TextTimeEmbedding(nn.Module): + def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): + super().__init__() + self.norm1 = nn.LayerNorm(encoder_dim) + self.pool = AttentionPooling(num_heads, encoder_dim) + self.proj = nn.Linear(encoder_dim, time_embed_dim) + self.norm2 = nn.LayerNorm(time_embed_dim) + + def forward(self, hidden_states): + hidden_states = self.norm1(hidden_states) + hidden_states = self.pool(hidden_states) + hidden_states = self.proj(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class TextImageTimeEmbedding(nn.Module): + def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) + self.text_norm = nn.LayerNorm(time_embed_dim) + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + # text + time_text_embeds = self.text_proj(text_embeds) + time_text_embeds = self.text_norm(time_text_embeds) + + # image + time_image_embeds = self.image_proj(image_embeds) + + return time_image_embeds + time_text_embeds + + +class ImageTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + + def forward(self, image_embeds: torch.Tensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + return time_image_embeds + + +class ImageHintTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 96, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(96, 96, 3, padding=1), + nn.SiLU(), + nn.Conv2d(96, 256, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(256, 4, 3, padding=1), + ) + + def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + hint = self.input_hint_block(hint) + return time_image_embeds, hint + + +class AttentionPooling(nn.Module): + # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 + + def __init__(self, num_heads, embed_dim, dtype=None): + super().__init__() + self.dtype = dtype + self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.num_heads = num_heads + self.dim_per_head = embed_dim // self.num_heads + + def forward(self, x): + bs, length, width = x.size() + + def shape(x): + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, -1, self.num_heads, self.dim_per_head) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) + # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) + x = x.transpose(1, 2) + return x + + class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) + x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) + + # (bs*n_heads, class_token_length, dim_per_head) + q = shape(self.q_proj(class_token)) + # (bs*n_heads, length+class_token_length, dim_per_head) + k = shape(self.k_proj(x)) + v = shape(self.v_proj(x)) + + # (bs*n_heads, class_token_length, length+class_token_length): + scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # (bs*n_heads, dim_per_head, class_token_length) + a = torch.einsum("bts,bcs->bct", weight, v) + + # (bs, length+1, width) + a = a.reshape(bs, -1, 1).transpose(1, 2) + + return a[:, 0, :] # cls_token + + +def get_fourier_embeds_from_boundingbox(embed_dim, box): + """ + Args: + embed_dim: int + box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline + Returns: + [B x N x embed_dim] tensor of positional embeddings + """ + + batch_size, num_boxes = box.shape[:2] + + emb = 100 ** (torch.arange(embed_dim) / embed_dim) + emb = emb[None, None, None].to(device=box.device, dtype=box.dtype) + emb = emb * box.unsqueeze(-1) + + emb = torch.stack((emb.sin(), emb.cos()), dim=-1) + emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4) + + return emb + + +class GLIGENTextBoundingboxProjection(nn.Module): + def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8): + super().__init__() + self.positive_len = positive_len + self.out_dim = out_dim + + self.fourier_embedder_dim = fourier_freqs + self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy + + if isinstance(out_dim, tuple): + out_dim = out_dim[0] + + if feature_type == "text-only": + self.linears = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + elif feature_type == "text-image": + self.linears_text = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.linears_image = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) + + def forward( + self, + boxes, + masks, + positive_embeddings=None, + phrases_masks=None, + image_masks=None, + phrases_embeddings=None, + image_embeddings=None, + ): + masks = masks.unsqueeze(-1) + + # embedding position (it may includes padding as placeholder) + xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C + + # learnable null embedding + xyxy_null = self.null_position_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null + + # positionet with text only information + if positive_embeddings is not None: + # learnable null embedding + positive_null = self.null_positive_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null + + objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) + + # positionet with text and image information + else: + phrases_masks = phrases_masks.unsqueeze(-1) + image_masks = image_masks.unsqueeze(-1) + + # learnable null embedding + text_null = self.null_text_feature.view(1, 1, -1) + image_null = self.null_image_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null + image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null + + objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) + objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) + objs = torch.cat([objs_text, objs_image], dim=1) + + return objs + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ + For PixArt-Alpha. + + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) + resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) + aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) + aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) + conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + + +class PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = nn.SiLU() + elif act_fn == "silu_fp32": + self.act_1 = FP32SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class IPAdapterPlusImageProjectionBlock(nn.Module): + def __init__( + self, + embed_dims: int = 768, + dim_head: int = 64, + heads: int = 16, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.ln0 = nn.LayerNorm(embed_dims) + self.ln1 = nn.LayerNorm(embed_dims) + self.attn = Attention( + query_dim=embed_dims, + dim_head=dim_head, + heads=heads, + out_bias=False, + ) + self.ff = nn.Sequential( + nn.LayerNorm(embed_dims), + FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), + ) + + def forward(self, x, latents, residual): + encoder_hidden_states = self.ln0(x) + latents = self.ln1(latents) + encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = self.attn(latents, encoder_hidden_states) + residual + latents = self.ff(latents) + latents + return latents + + +class IPAdapterPlusImageProjection(nn.Module): + """Resampler of IP-Adapter Plus. + + Args: + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_queries (int): + The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio + of feedforward network hidden + layer channels. Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 1024, + hidden_dims: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_queries: int = 8, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) + + self.proj_in = nn.Linear(embed_dims, hidden_dims) + + self.proj_out = nn.Linear(hidden_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList( + [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x (torch.Tensor): Input Tensor. + Returns: + torch.Tensor: Output Tensor. + """ + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for block in self.layers: + residual = latents + latents = block(x, latents, residual) + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class IPAdapterFaceIDPlusImageProjection(nn.Module): + """FacePerceiverResampler of IP-Adapter Plus. + + Args: + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + ffproj_ratio (float): The expansion ratio of feedforward network hidden + layer channels (for ID embeddings). Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 768, + hidden_dims: int = 1280, + id_embeddings_dim: int = 512, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_tokens: int = 4, + num_queries: int = 8, + ffn_ratio: float = 4, + ffproj_ratio: int = 2, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.num_tokens = num_tokens + self.embed_dim = embed_dims + self.clip_embeds = None + self.shortcut = False + self.shortcut_scale = 1.0 + + self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio) + self.norm = nn.LayerNorm(embed_dims) + + self.proj_in = nn.Linear(hidden_dims, embed_dims) + + self.proj_out = nn.Linear(embed_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList( + [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + + def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + id_embeds (torch.Tensor): Input Tensor (ID embeds). + Returns: + torch.Tensor: Output Tensor. + """ + id_embeds = id_embeds.to(self.clip_embeds.dtype) + id_embeds = self.proj(id_embeds) + id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim) + id_embeds = self.norm(id_embeds) + latents = id_embeds + + clip_embeds = self.proj_in(self.clip_embeds) + x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) + + for block in self.layers: + residual = latents + latents = block(x, latents, residual) + + latents = self.proj_out(latents) + out = self.norm_out(latents) + if self.shortcut: + out = id_embeds + self.shortcut_scale * out + return out + + +class MultiIPAdapterImageProjection(nn.Module): + def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): + super().__init__() + self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers) + + def forward(self, image_embeds: List[torch.Tensor]): + projected_image_embeds = [] + + # currently, we accept `image_embeds` as + # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim] + # 2. list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim] + if not isinstance(image_embeds, list): + deprecation_message = ( + "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning." + ) + deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False) + image_embeds = [image_embeds.unsqueeze(1)] + + if len(image_embeds) != len(self.image_projection_layers): + raise ValueError( + f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}" + ) + + for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): + batch_size, num_images = image_embed.shape[0], image_embed.shape[1] + image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) + image_embed = image_projection_layer(image_embed) + image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) + + projected_image_embeds.append(image_embed) + + return projected_image_embeds diff --git a/diffusers/src/diffusers/models/embeddings_flax.py b/diffusers/src/diffusers/models/embeddings_flax.py new file mode 100755 index 0000000..8e343be --- /dev/null +++ b/diffusers/src/diffusers/models/embeddings_flax.py @@ -0,0 +1,97 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import flax.linen as nn +import jax.numpy as jnp + + +def get_sinusoidal_embeddings( + timesteps: jnp.ndarray, + embedding_dim: int, + freq_shift: float = 1, + min_timescale: float = 1, + max_timescale: float = 1.0e4, + flip_sin_to_cos: bool = False, + scale: float = 1.0, +) -> jnp.ndarray: + """Returns the positional encoding (same as Tensor2Tensor). + + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + embedding_dim: The number of output channels. + min_timescale: The smallest time unit (should probably be 0.0). + max_timescale: The largest time unit. + Returns: + a Tensor of timing signals [N, num_channels] + """ + assert timesteps.ndim == 1, "Timesteps should be a 1d-array" + assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" + num_timescales = float(embedding_dim // 2) + log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) + inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) + emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) + + # scale embeddings + scaled_time = scale * emb + + if flip_sin_to_cos: + signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) + else: + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) + signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) + return signal + + +class FlaxTimestepEmbedding(nn.Module): + r""" + Time step Embedding Module. Learns embeddings for input time steps. + + Args: + time_embed_dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + time_embed_dim: int = 32 + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__(self, temb): + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) + temb = nn.silu(temb) + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) + return temb + + +class FlaxTimesteps(nn.Module): + r""" + Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 + + Args: + dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension + """ + + dim: int = 32 + flip_sin_to_cos: bool = False + freq_shift: float = 1 + + @nn.compact + def __call__(self, timesteps): + return get_sinusoidal_embeddings( + timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift + ) diff --git a/diffusers/src/diffusers/models/lora.py b/diffusers/src/diffusers/models/lora.py new file mode 100755 index 0000000..4e9e0c0 --- /dev/null +++ b/diffusers/src/diffusers/models/lora.py @@ -0,0 +1,457 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# IMPORTANT: # +################################################################### +# ----------------------------------------------------------------# +# This file is deprecated and will be removed soon # +# (as soon as PEFT will become a required dependency for LoRA) # +# ----------------------------------------------------------------# +################################################################### + +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..utils import deprecate, logging +from ..utils.import_utils import is_transformers_available + + +if is_transformers_available(): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def text_encoder_attn_modules(text_encoder): + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + else: + raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") + + return attn_modules + + +def text_encoder_mlp_modules(text_encoder): + mlp_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + mlp_mod = layer.mlp + name = f"text_model.encoder.layers.{i}.mlp" + mlp_modules.append((name, mlp_mod)) + else: + raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}") + + return mlp_modules + + +def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj.lora_scale = lora_scale + attn_module.k_proj.lora_scale = lora_scale + attn_module.v_proj.lora_scale = lora_scale + attn_module.out_proj.lora_scale = lora_scale + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1.lora_scale = lora_scale + mlp_module.fc2.lora_scale = lora_scale + + +class PatchedLoraProjection(torch.nn.Module): + def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None): + deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." + deprecate("PatchedLoraProjection", "1.0.0", deprecation_message) + + super().__init__() + from ..models.lora import LoRALinearLayer + + self.regular_linear_layer = regular_linear_layer + + device = self.regular_linear_layer.weight.device + + if dtype is None: + dtype = self.regular_linear_layer.weight.dtype + + self.lora_linear_layer = LoRALinearLayer( + self.regular_linear_layer.in_features, + self.regular_linear_layer.out_features, + network_alpha=network_alpha, + device=device, + dtype=dtype, + rank=rank, + ) + + self.lora_scale = lora_scale + + # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved + # when saving the whole text encoder model and when LoRA is unloaded or fused + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + if self.lora_linear_layer is None: + return self.regular_linear_layer.state_dict( + *args, destination=destination, prefix=prefix, keep_vars=keep_vars + ) + + return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) + + def _fuse_lora(self, lora_scale=1.0, safe_fusing=False): + if self.lora_linear_layer is None: + return + + dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device + + w_orig = self.regular_linear_layer.weight.data.float() + w_up = self.lora_linear_layer.up.weight.data.float() + w_down = self.lora_linear_layer.down.weight.data.float() + + if self.lora_linear_layer.network_alpha is not None: + w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank + + fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + + if safe_fusing and torch.isnan(fused_weight).any().item(): + raise ValueError( + "This LoRA weight seems to be broken. " + f"Encountered NaN values when trying to fuse LoRA weights for {self}." + "LoRA weights will not be fused." + ) + + self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_linear_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + self.lora_scale = lora_scale + + def _unfuse_lora(self): + if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None): + return + + fused_weight = self.regular_linear_layer.weight.data + dtype, device = fused_weight.dtype, fused_weight.device + + w_up = self.w_up.to(device=device).float() + w_down = self.w_down.to(device).float() + + unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + + def forward(self, input): + if self.lora_scale is None: + self.lora_scale = 1.0 + if self.lora_linear_layer is None: + return self.regular_linear_layer(input) + return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input)) + + +class LoRALinearLayer(nn.Module): + r""" + A linear layer that is used with LoRA. + + Parameters: + in_features (`int`): + Number of input features. + out_features (`int`): + Number of output features. + rank (`int`, `optional`, defaults to 4): + The rank of the LoRA layer. + network_alpha (`float`, `optional`, defaults to `None`): + The value of the network alpha used for stable learning and preventing underflow. This value has the same + meaning as the `--network_alpha` option in the kohya-ss trainer script. See + https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + device (`torch.device`, `optional`, defaults to `None`): + The device to use for the layer's weights. + dtype (`torch.dtype`, `optional`, defaults to `None`): + The dtype to use for the layer's weights. + """ + + def __init__( + self, + in_features: int, + out_features: int, + rank: int = 4, + network_alpha: Optional[float] = None, + device: Optional[Union[torch.device, str]] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + + deprecation_message = "Use of `LoRALinearLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." + deprecate("LoRALinearLayer", "1.0.0", deprecation_message) + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + self.out_features = out_features + self.in_features = in_features + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class LoRAConv2dLayer(nn.Module): + r""" + A convolutional layer that is used with LoRA. + + Parameters: + in_features (`int`): + Number of input features. + out_features (`int`): + Number of output features. + rank (`int`, `optional`, defaults to 4): + The rank of the LoRA layer. + kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1): + The kernel size of the convolution. + stride (`int` or `tuple` of two `int`, `optional`, defaults to 1): + The stride of the convolution. + padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0): + The padding of the convolution. + network_alpha (`float`, `optional`, defaults to `None`): + The value of the network alpha used for stable learning and preventing underflow. This value has the same + meaning as the `--network_alpha` option in the kohya-ss trainer script. See + https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + """ + + def __init__( + self, + in_features: int, + out_features: int, + rank: int = 4, + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Union[int, Tuple[int, int]] = (1, 1), + padding: Union[int, Tuple[int, int], str] = 0, + network_alpha: Optional[float] = None, + ): + super().__init__() + + deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." + deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message) + + self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + # according to the official kohya_ss trainer kernel_size are always fixed for the up layer + # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 + self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False) + + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class LoRACompatibleConv(nn.Conv2d): + """ + A convolutional layer that can be used with LoRA. + """ + + def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs): + deprecation_message = "Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." + deprecate("LoRACompatibleConv", "1.0.0", deprecation_message) + + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + + def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): + deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." + deprecate("set_lora_layer", "1.0.0", deprecation_message) + + self.lora_layer = lora_layer + + def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False): + if self.lora_layer is None: + return + + dtype, device = self.weight.data.dtype, self.weight.data.device + + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + + fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)) + fusion = fusion.reshape((w_orig.shape)) + fused_weight = w_orig + (lora_scale * fusion) + + if safe_fusing and torch.isnan(fused_weight).any().item(): + raise ValueError( + "This LoRA weight seems to be broken. " + f"Encountered NaN values when trying to fuse LoRA weights for {self}." + "LoRA weights will not be fused." + ) + + self.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + self._lora_scale = lora_scale + + def _unfuse_lora(self): + if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None): + return + + fused_weight = self.weight.data + dtype, device = fused_weight.data.dtype, fused_weight.data.device + + self.w_up = self.w_up.to(device=device).float() + self.w_down = self.w_down.to(device).float() + + fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1)) + fusion = fusion.reshape((fused_weight.shape)) + unfused_weight = fused_weight.float() - (self._lora_scale * fusion) + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + if self.padding_mode != "zeros": + hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode) + padding = (0, 0) + else: + padding = self.padding + + original_outputs = F.conv2d( + hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups + ) + + if self.lora_layer is None: + return original_outputs + else: + return original_outputs + (scale * self.lora_layer(hidden_states)) + + +class LoRACompatibleLinear(nn.Linear): + """ + A Linear layer that can be used with LoRA. + """ + + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): + deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." + deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message) + + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + + def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): + deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." + deprecate("set_lora_layer", "1.0.0", deprecation_message) + self.lora_layer = lora_layer + + def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False): + if self.lora_layer is None: + return + + dtype, device = self.weight.data.dtype, self.weight.data.device + + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + + fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + + if safe_fusing and torch.isnan(fused_weight).any().item(): + raise ValueError( + "This LoRA weight seems to be broken. " + f"Encountered NaN values when trying to fuse LoRA weights for {self}." + "LoRA weights will not be fused." + ) + + self.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + self._lora_scale = lora_scale + + def _unfuse_lora(self): + if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None): + return + + fused_weight = self.weight.data + dtype, device = fused_weight.dtype, fused_weight.device + + w_up = self.w_up.to(device=device).float() + w_down = self.w_down.to(device).float() + + unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + if self.lora_layer is None: + out = super().forward(hidden_states) + return out + else: + out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + return out diff --git a/diffusers/src/diffusers/models/model_loading_utils.py b/diffusers/src/diffusers/models/model_loading_utils.py new file mode 100755 index 0000000..969eb5f --- /dev/null +++ b/diffusers/src/diffusers/models/model_loading_utils.py @@ -0,0 +1,230 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +from collections import OrderedDict +from pathlib import Path +from typing import List, Optional, Union + +import safetensors +import torch +from huggingface_hub.utils import EntryNotFoundError + +from ..utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFETENSORS_FILE_EXTENSION, + WEIGHTS_INDEX_NAME, + _add_variant, + _get_model_file, + is_accelerate_available, + is_torch_version, + logging, +) + + +logger = logging.get_logger(__name__) + +_CLASS_REMAPPING_DICT = { + "Transformer2DModel": { + "ada_norm_zero": "DiTTransformer2DModel", + "ada_norm_single": "PixArtTransformer2DModel", + } +} + + +if is_accelerate_available(): + from accelerate import infer_auto_device_map + from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device + + +# Adapted from `transformers` (see modeling_utils.py) +def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype): + if isinstance(device_map, str): + no_split_modules = model._get_no_split_modules(device_map) + device_map_kwargs = {"no_split_module_classes": no_split_modules} + + if device_map != "sequential": + max_memory = get_balanced_memory( + model, + dtype=torch_dtype, + low_zero=(device_map == "balanced_low_0"), + max_memory=max_memory, + **device_map_kwargs, + ) + else: + max_memory = get_max_memory(max_memory) + + device_map_kwargs["max_memory"] = max_memory + device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs) + + return device_map + + +def _fetch_remapped_cls_from_config(config, old_class): + previous_class_name = old_class.__name__ + remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None) + + # Details: + # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818 + if remapped_class_name: + # load diffusers library to import compatible and original scheduler + diffusers_library = importlib.import_module(__name__.split(".")[0]) + remapped_class = getattr(diffusers_library, remapped_class_name) + logger.info( + f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type." + f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this" + " DOESN'T affect the final results." + ) + return remapped_class + else: + return old_class + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): + """ + Reads a checkpoint file, returning properly formatted errors if they arise. + """ + try: + file_extension = os.path.basename(checkpoint_file).split(".")[-1] + if file_extension == SAFETENSORS_FILE_EXTENSION: + return safetensors.torch.load_file(checkpoint_file, device="cpu") + else: + weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {} + return torch.load( + checkpoint_file, + map_location="cpu", + **weights_only_kwarg, + ) + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " + ) + + +def load_model_dict_into_meta( + model, + state_dict: OrderedDict, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + model_name_or_path: Optional[str] = None, +) -> List[str]: + device = device or torch.device("cpu") + dtype = dtype or torch.float32 + + accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) + + unexpected_keys = [] + empty_state_dict = model.state_dict() + for param_name, param in state_dict.items(): + if param_name not in empty_state_dict: + unexpected_keys.append(param_name) + continue + + if empty_state_dict[param_name].shape != param.shape: + model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" + raise ValueError( + f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." + ) + + if accepts_dtype: + set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) + else: + set_module_tensor_to_device(model, param_name, device, value=param) + return unexpected_keys + + +def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix: str = ""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load) + + return error_msgs + + +def _fetch_index_file( + is_local, + pretrained_model_name_or_path, + subfolder, + use_safetensors, + cache_dir, + variant, + force_download, + proxies, + local_files_only, + token, + revision, + user_agent, + commit_hash, +): + if is_local: + index_file = Path( + pretrained_model_name_or_path, + subfolder or "", + _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant), + ) + else: + index_file_in_repo = Path( + subfolder or "", + _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant), + ).as_posix() + try: + index_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=index_file_in_repo, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=None, + user_agent=user_agent, + commit_hash=commit_hash, + ) + index_file = Path(index_file) + except (EntryNotFoundError, EnvironmentError): + index_file = None + + return index_file diff --git a/diffusers/src/diffusers/models/modeling_flax_pytorch_utils.py b/diffusers/src/diffusers/models/modeling_flax_pytorch_utils.py new file mode 100755 index 0000000..4db537f --- /dev/null +++ b/diffusers/src/diffusers/models/modeling_flax_pytorch_utils.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch - Flax general utilities.""" + +import re + +import jax.numpy as jnp +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey + +from ..utils import logging + + +logger = logging.get_logger(__name__) + + +def rename_key(key): + regex = r"\w+[.]\d+" + pats = re.findall(regex, key) + for pat in pats: + key = key.replace(pat, "_".join(pat.split("."))) + return key + + +##################### +# PyTorch => Flax # +##################### + + +# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 +# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py +def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): + """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" + # conv norm or layer norm + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + + # rename attention layers + if len(pt_tuple_key) > 1: + for rename_from, rename_to in ( + ("to_out_0", "proj_attn"), + ("to_k", "key"), + ("to_v", "value"), + ("to_q", "query"), + ): + if pt_tuple_key[-2] == rename_from: + weight_name = pt_tuple_key[-1] + weight_name = "kernel" if weight_name == "weight" else weight_name + renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name) + if renamed_pt_tuple_key in random_flax_state_dict: + assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape + return renamed_pt_tuple_key, pt_tensor.T + + if ( + any("norm" in str_ for str_ in pt_tuple_key) + and (pt_tuple_key[-1] == "bias") + and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) + and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) + ): + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + return renamed_pt_tuple_key, pt_tensor + elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + return renamed_pt_tuple_key, pt_tensor + + # embedding + if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: + pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) + return renamed_pt_tuple_key, pt_tensor + + # conv layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: + pt_tensor = pt_tensor.transpose(2, 3, 1, 0) + return renamed_pt_tuple_key, pt_tensor + + # linear layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight": + pt_tensor = pt_tensor.T + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm weight + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) + if pt_tuple_key[-1] == "gamma": + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm bias + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) + if pt_tuple_key[-1] == "beta": + return renamed_pt_tuple_key, pt_tensor + + return pt_tuple_key, pt_tensor + + +def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): + # Step 1: Convert pytorch tensor to numpy + pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} + + # Step 2: Since the model is stateless, get random Flax params + random_flax_params = flax_model.init_weights(PRNGKey(init_key)) + + random_flax_state_dict = flatten_dict(random_flax_params) + flax_state_dict = {} + + # Need to change some parameters name to match Flax names + for pt_key, pt_tensor in pt_state_dict.items(): + renamed_pt_key = rename_key(pt_key) + pt_tuple_key = tuple(renamed_pt_key.split(".")) + + # Correctly rename weight parameters + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict) + + if flax_key in random_flax_state_dict: + if flax_tensor.shape != random_flax_state_dict[flax_key].shape: + raise ValueError( + f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " + f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + + # also add unexpected weight so that warning is thrown + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + + return unflatten_dict(flax_state_dict) diff --git a/diffusers/src/diffusers/models/modeling_flax_utils.py b/diffusers/src/diffusers/models/modeling_flax_utils.py new file mode 100755 index 0000000..8c35fab --- /dev/null +++ b/diffusers/src/diffusers/models/modeling_flax_utils.py @@ -0,0 +1,561 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pickle import UnpicklingError +from typing import Any, Dict, Union + +import jax +import jax.numpy as jnp +import msgpack.exceptions +from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.serialization import from_bytes, to_bytes +from flax.traverse_util import flatten_dict, unflatten_dict +from huggingface_hub import create_repo, hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, + validate_hf_hub_args, +) +from requests import HTTPError + +from .. import __version__, is_torch_available +from ..utils import ( + CONFIG_NAME, + FLAX_WEIGHTS_NAME, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + WEIGHTS_NAME, + PushToHubMixin, + logging, +) +from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax + + +logger = logging.get_logger(__name__) + + +class FlaxModelMixin(PushToHubMixin): + r""" + Base class for all Flax models. + + [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and + saving models. + + - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`]. + """ + + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _flax_internal_args = ["name", "parent", "dtype"] + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + """ + return cls(config, **kwargs) + + def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: + """ + Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. + """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_map(conditional_cast, params) + + flat_params = flatten_dict(params) + flat_mask, _ = jax.tree_flatten(mask) + + for masked, key in zip(flat_mask, flat_params.keys()): + if masked: + param = flat_params[key] + flat_params[key] = conditional_cast(param) + + return unflatten_dict(flat_params) + + def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast + the `params` in place. + + This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full + half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` + for params you want to cast, and `False` for those you want to skip. + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # load model + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision + >>> params = model.to_bf16(params) + >>> # If you don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> flat_params = traverse_util.flatten_dict(params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> params = model.to_bf16(params, mask) + ```""" + return self._cast_floating_to(params, jnp.bfloat16, mask) + + def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the + model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` + for params you want to cast, and `False` for those you want to skip. + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # Download model and configuration from huggingface.co + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> # By default, the model params will be in fp32, to illustrate the use of this method, + >>> # we'll first cast to fp16 and back to fp32 + >>> params = model.to_f16(params) + >>> # now cast back to fp32 + >>> params = model.to_fp32(params) + ```""" + return self._cast_floating_to(params, jnp.float32, mask) + + def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the + `params` in place. + + This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full + half-precision training or to save weights in float16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` + for params you want to cast, and `False` for those you want to skip. + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # load model + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> # By default, the model params will be in fp32, to cast these to float16 + >>> params = model.to_fp16(params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> flat_params = traverse_util.flatten_dict(params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> params = model.to_fp16(params, mask) + ```""" + return self._cast_floating_to(params, jnp.float16, mask) + + def init_weights(self, rng: jax.Array) -> Dict: + raise NotImplementedError(f"init_weights method has to be implemented for {self}") + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + dtype: jnp.dtype = jnp.float32, + *model_args, + **kwargs, + ): + r""" + Instantiate a pretrained Flax model from a pretrained model configuration. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model + hosted on the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + using [`~FlaxModelMixin.save_pretrained`]. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified, all the computation will be performed with the given `dtype`. + + + + This only specifies the dtype of the *computation* and does not influence the dtype of model + parameters. + + If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and + [`~FlaxModelMixin.to_bf16`]. + + + + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments are passed to the underlying model's `__init__` method. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it is loaded) and initiate the model (for + example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying + model's `__init__` method (we assume all relevant updates to the configuration have already been + done). + - If a configuration is not provided, `kwargs` are first passed to the configuration class + initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds + to a configuration attribute is used to override said attribute with the supplied `kwargs` value. + Remaining keys that do not correspond to any configuration attribute are passed to the underlying + model's `__init__` function. + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + config = kwargs.pop("config", None) + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + from_pt = kwargs.pop("from_pt", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "flax", + } + + # Load config if we don't provide one + if config is None: + config, unused_kwargs = cls.load_config( + pretrained_model_name_or_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + **kwargs, + ) + + model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs) + + # Load model + pretrained_path_with_subfolder = ( + pretrained_model_name_or_path + if subfolder is None + else os.path.join(pretrained_model_name_or_path, subfolder) + ) + if os.path.isdir(pretrained_path_with_subfolder): + if from_pt: + if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): + raise EnvironmentError( + f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} " + ) + model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint + model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME) + # Check if pytorch weights exist instead + elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): + raise EnvironmentError( + f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model" + " using `from_pt=True`." + ) + else: + raise EnvironmentError( + f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " + f"{pretrained_path_with_subfolder}." + ) + else: + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `token` or log in with `huggingface-cli " + "login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}." + ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" + f"{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your" + " internet connection or see how to run the library in offline mode at" + " 'https://huggingface.co/docs/transformers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." + ) + + if from_pt: + if is_torch_available(): + from .modeling_utils import load_state_dict + else: + raise EnvironmentError( + "Can't load the model in PyTorch format because PyTorch is not installed. " + "Please, install PyTorch or use native Flax weights." + ) + + # Step 1: Get the pytorch file + pytorch_model_file = load_state_dict(model_file) + + # Step 2: Convert the weights + state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model) + else: + try: + with open(model_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + try: + with open(model_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") + # make sure all arrays are stored as jnp.ndarray + # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: + # https://github.com/google/flax/issues/1261 + state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state) + + # flatten dicts + state = flatten_dict(state) + + params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0)) + required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) + + shape_state = flatten_dict(unfreeze(params_shape_tree)) + + missing_keys = required_params - set(state.keys()) + unexpected_keys = set(state.keys()) - required_params + + if missing_keys: + logger.warning( + f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " + "Make sure to call model.init_weights to initialize the missing weights." + ) + cls._missing_keys = missing_keys + + for key in state.keys(): + if key in shape_state and state[key].shape != shape_state[key].shape: + raise ValueError( + f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " + f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. " + ) + + # remove unexpected keys to not be saved again + for unexpected_key in unexpected_keys: + del state[unexpected_key] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + else: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + + return model, unflatten_dict(state) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + params: Union[Dict, FrozenDict], + is_main_process: bool = True, + push_to_hub: bool = False, + **kwargs, + ): + """ + Save a model and its configuration file to a directory so that it can be reloaded using the + [`~FlaxModelMixin.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save a model and its configuration file to. Will be created if it doesn't exist. + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # save model + output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) + with open(output_model_file, "wb") as f: + model_bytes = to_bytes(params) + f.write(model_bytes) + + logger.info(f"Model weights saved in {output_model_file}") + + if push_to_hub: + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) diff --git a/diffusers/src/diffusers/models/modeling_outputs.py b/diffusers/src/diffusers/models/modeling_outputs.py new file mode 100755 index 0000000..0120a34 --- /dev/null +++ b/diffusers/src/diffusers/models/modeling_outputs.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass + +from ..utils import BaseOutput + + +@dataclass +class AutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "DiagonalGaussianDistribution" # noqa: F821 + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: "torch.Tensor" # noqa: F821 diff --git a/diffusers/src/diffusers/models/modeling_pytorch_flax_utils.py b/diffusers/src/diffusers/models/modeling_pytorch_flax_utils.py new file mode 100755 index 0000000..55eff0e --- /dev/null +++ b/diffusers/src/diffusers/models/modeling_pytorch_flax_utils.py @@ -0,0 +1,161 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch - Flax general utilities.""" + +from pickle import UnpicklingError + +import jax +import jax.numpy as jnp +import numpy as np +from flax.serialization import from_bytes +from flax.traverse_util import flatten_dict + +from ..utils import logging + + +logger = logging.get_logger(__name__) + + +##################### +# Flax => PyTorch # +##################### + + +# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352 +def load_flax_checkpoint_in_pytorch_model(pt_model, model_file): + try: + with open(model_file, "rb") as flax_state_f: + flax_state = from_bytes(None, flax_state_f.read()) + except UnpicklingError as e: + try: + with open(model_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") + + return load_flax_weights_in_pytorch_model(pt_model, flax_state) + + +def load_flax_weights_in_pytorch_model(pt_model, flax_state): + """Load flax checkpoints in a PyTorch model""" + + try: + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" + " instructions." + ) + raise + + # check if we have bf16 weights + is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() + if any(is_type_bf16): + # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16 + + # and bf16 is not fully supported in PT yet. + logger.warning( + "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " + "before loading those in PyTorch model." + ) + flax_state = jax.tree_util.tree_map( + lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state + ) + + pt_model.base_model_prefix = "" + + flax_state_dict = flatten_dict(flax_state, sep=".") + pt_model_dict = pt_model.state_dict() + + # keep track of unexpected & missing keys + unexpected_keys = [] + missing_keys = set(pt_model_dict.keys()) + + for flax_key_tuple, flax_tensor in flax_state_dict.items(): + flax_key_tuple_array = flax_key_tuple.split(".") + + if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4: + flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] + flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1)) + elif flax_key_tuple_array[-1] == "kernel": + flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] + flax_tensor = flax_tensor.T + elif flax_key_tuple_array[-1] == "scale": + flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] + + if "time_embedding" not in flax_key_tuple_array: + for i, flax_key_tuple_string in enumerate(flax_key_tuple_array): + flax_key_tuple_array[i] = ( + flax_key_tuple_string.replace("_0", ".0") + .replace("_1", ".1") + .replace("_2", ".2") + .replace("_3", ".3") + .replace("_4", ".4") + .replace("_5", ".5") + .replace("_6", ".6") + .replace("_7", ".7") + .replace("_8", ".8") + .replace("_9", ".9") + ) + + flax_key = ".".join(flax_key_tuple_array) + + if flax_key in pt_model_dict: + if flax_tensor.shape != pt_model_dict[flax_key].shape: + raise ValueError( + f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected " + f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + else: + # add weight to pytorch dict + flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor + pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) + # remove from missing keys + missing_keys.remove(flax_key) + else: + # weight is not expected by PyTorch model + unexpected_keys.append(flax_key) + + pt_model.load_state_dict(pt_model_dict) + + # re-transform missing_keys to list + missing_keys = list(missing_keys) + + if len(unexpected_keys) > 0: + logger.warning( + "Some weights of the Flax model were not used when initializing the PyTorch model" + f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" + f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture" + " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This" + f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect" + " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a" + " FlaxBertForSequenceClassification model)." + ) + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly" + f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" + " use it for predictions and inference." + ) + + return pt_model diff --git a/diffusers/src/diffusers/models/modeling_utils.py b/diffusers/src/diffusers/models/modeling_utils.py new file mode 100755 index 0000000..cfe692d --- /dev/null +++ b/diffusers/src/diffusers/models/modeling_utils.py @@ -0,0 +1,1203 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import itertools +import json +import os +import re +from collections import OrderedDict +from functools import partial +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union + +import safetensors +import torch +from huggingface_hub import create_repo, split_torch_state_dict_into_shards +from huggingface_hub.utils import validate_hf_hub_args +from torch import Tensor, nn + +from .. import __version__ +from ..utils import ( + CONFIG_NAME, + FLAX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + _add_variant, + _get_checkpoint_shard_files, + _get_model_file, + deprecate, + is_accelerate_available, + is_torch_version, + logging, +) +from ..utils.hub_utils import ( + PushToHubMixin, + load_or_create_model_card, + populate_model_card, +) +from .model_loading_utils import ( + _determine_device_map, + _fetch_index_file, + _load_state_dict_into_model, + load_model_dict_into_meta, + load_state_dict, +) + + +logger = logging.get_logger(__name__) + +_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}") + + +if is_torch_version(">=", "1.9.0"): + _LOW_CPU_MEM_USAGE_DEFAULT = True +else: + _LOW_CPU_MEM_USAGE_DEFAULT = False + + +if is_accelerate_available(): + import accelerate + + +def get_parameter_device(parameter: torch.nn.Module) -> torch.device: + try: + parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) + return next(parameters_and_buffers).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: + try: + params = tuple(parameter.parameters()) + if len(params) > 0: + return params[0].dtype + + buffers = tuple(parameter.buffers()) + if len(buffers) > 0: + return buffers[0].dtype + + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +class ModelMixin(torch.nn.Module, PushToHubMixin): + r""" + Base class for all models. + + [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and + saving models. + + - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`]. + """ + + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + _keys_to_ignore_on_load_unexpected = None + _no_split_modules = None + + def __init__(self): + super().__init__() + + def __getattr__(self, name: str) -> Any: + """The only reason we overwrite `getattr` here is to gracefully deprecate accessing + config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite + __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__': + https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + """ + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) + return self._internal_dict[name] + + # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + return super().__getattr__(name) + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self) -> None: + """ + Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or + *checkpoint activations* in other frameworks). + """ + if not self._supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def disable_gradient_checkpointing(self) -> None: + """ + Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or + *checkpoint activations* in other frameworks). + """ + if self._supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + def set_use_npu_flash_attention(self, valid: bool) -> None: + r""" + Set the switch for the npu flash attention. + """ + + def fn_recursive_set_npu_flash_attention(module: torch.nn.Module): + if hasattr(module, "set_use_npu_flash_attention"): + module.set_use_npu_flash_attention(valid) + + for child in module.children(): + fn_recursive_set_npu_flash_attention(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_npu_flash_attention(module) + + def enable_npu_flash_attention(self) -> None: + r""" + Enable npu flash attention from torch_npu + + """ + self.set_use_npu_flash_attention(True) + + def disable_npu_flash_attention(self) -> None: + r""" + disable npu flash attention from torch_npu + + """ + self.set_use_npu_flash_attention(False) + + def set_use_memory_efficient_attention_xformers( + self, valid: bool, attention_op: Optional[Callable] = None + ) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid, attention_op) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None: + r""" + Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up during + inference. Speed up during training is not guaranteed. + + + + ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes + precedent. + + + + Parameters: + attention_op (`Callable`, *optional*): + Override the default `None` operator for use as `op` argument to the + [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) + function of xFormers. + + Examples: + + ```py + >>> import torch + >>> from diffusers import UNet2DConditionModel + >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + + >>> model = UNet2DConditionModel.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16 + ... ) + >>> model = model.to("cuda") + >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + ``` + """ + self.set_use_memory_efficient_attention_xformers(True, attention_op) + + def disable_xformers_memory_efficient_attention(self) -> None: + r""" + Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). + """ + self.set_use_memory_efficient_attention_xformers(False) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Optional[Callable] = None, + safe_serialization: bool = True, + variant: Optional[str] = None, + max_shard_size: Union[int, str] = "10GB", + push_to_hub: bool = False, + **kwargs, + ): + """ + Save a model and its configuration file to a directory so that it can be reloaded using the + [`~models.ModelMixin.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save a model and its configuration file to. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + variant (`str`, *optional*): + If specified, weights are saved in the format `pytorch_model..bin`. + max_shard_size (`int` or `str`, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`). + If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain + period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`. + This is to establish a common default size for this argument across different libraries in the Hugging + Face ecosystem (`transformers`, and `accelerate`, for example). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + weight_name_split = weights_name.split(".") + if len(weight_name_split) in [2, 3]: + weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:]) + else: + raise ValueError(f"Invalid {weights_name} provided.") + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + # Only save the model itself if we are using distributed training + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # Save the model + state_dict = model_to_save.state_dict() + + # Save the model + state_dict_split = split_torch_state_dict_into_shards( + state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern + ) + + # Clean the folder from a previous save + if is_main_process: + for filename in os.listdir(save_directory): + if filename in state_dict_split.filename_to_tensors.keys(): + continue + full_filename = os.path.join(save_directory, filename) + if not os.path.isfile(full_filename): + continue + weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "") + weights_without_ext = weights_without_ext.replace("{suffix}", "") + filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "") + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + if ( + filename.startswith(weights_without_ext) + and _REGEX_SHARD.fullmatch(filename_without_ext) is not None + ): + os.remove(full_filename) + + for filename, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor] for tensor in tensors} + filepath = os.path.join(save_directory, filename) + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. + safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) + else: + torch.save(shard, filepath) + + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + else: + path_to_weights = os.path.join(save_directory, weights_name) + logger.info(f"Model weights saved in {path_to_weights}") + + if push_to_hub: + # Create a new empty model card and eventually tag it + model_card = load_or_create_model_card(repo_id, token=token) + model_card = populate_model_card(model_card) + model_card.save(Path(save_directory, "README.md").as_posix()) + + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained PyTorch model from a pretrained model configuration. + + The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To + train the model, set it back in training mode with `model.train()`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`~ModelMixin.save_pretrained`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info (`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. Defaults to `None`, meaning that the model will be loaded on CPU. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if `device_map` contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + variant (`str`, *optional*): + Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` + weights. If set to `False`, `safetensors` weights are not loaded. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. + + + + Example: + + ```py + from diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + cache_dir = kwargs.pop("cache_dir", None) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + # change device_map into a map if we passed an int, a str or a torch.device + if isinstance(device_map, torch.device): + device_map = {"": device_map} + elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + try: + device_map = {"": torch.device(device_map)} + except RuntimeError: + raise ValueError( + "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or " + f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}." + ) + elif isinstance(device_map, int): + if device_map < 0: + raise ValueError( + "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' " + ) + else: + device_map = {"": device_map} + + if device_map is not None: + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + elif not low_cpu_mem_usage: + raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") + + if low_cpu_mem_usage: + if device_map is not None and not is_torch_version(">=", "1.10"): + # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. + raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.") + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + **kwargs, + ) + + # Determine if we're loading from a directory of sharded checkpoints. + is_sharded = False + index_file = None + is_local = os.path.isdir(pretrained_model_name_or_path) + index_file = _fetch_index_file( + is_local=is_local, + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder or "", + use_safetensors=use_safetensors, + cache_dir=cache_dir, + variant=variant, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + user_agent=user_agent, + commit_hash=commit_hash, + ) + if index_file is not None and index_file.is_file(): + is_sharded = True + + if is_sharded and from_flax: + raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") + + # load model + model_file = None + if from_flax: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + model = cls.from_config(config, **unused_kwargs) + + # Convert the weights + from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, model_file) + else: + if is_sharded: + sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + ) + + elif use_safetensors and not is_sharded: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") + if not allow_pickle: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." + ) + + if model_file is None and not is_sharded: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **unused_kwargs) + + # if device_map is None, load the state dict and move the params from meta device to the cpu + if device_map is None and not is_sharded: + param_device = "cpu" + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_name_or_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: # else let accelerate handle loading and dispatching. + # Load weights and dispatch according to the device_map + # by default the device_map is None and the weights are loaded on the CPU + force_hook = True + device_map = _determine_device_map(model, device_map, max_memory, torch_dtype) + if device_map is None and is_sharded: + # we load the parameters on the cpu + device_map = {"": "cpu"} + force_hook = False + try: + accelerate.load_checkpoint_and_dispatch( + model, + model_file if not is_sharded else index_file, + device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + force_hooks=force_hook, + strict=True, + ) + except AttributeError as e: + # When using accelerate loading, we do not have the ability to load the state + # dict and rename the weight names manually. Additionally, accelerate skips + # torch loading conventions and directly writes into `module.{_buffers, _parameters}` + # (which look like they should be private variables?), so we can't use the standard hooks + # to rename parameters on load. We need to mimic the original weight names so the correct + # attributes are available. After we have loaded the weights, we convert the deprecated + # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert + # the weights so we don't have to do this again. + + if "'Attention' object has no attribute" in str(e): + logger.warning( + f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" + " was saved with deprecated attention block weight names. We will load it with the deprecated attention block" + " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," + " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," + " please also re-upload it or open a PR on the original repository." + ) + model._temp_convert_self_to_deprecated_attention_blocks() + accelerate.load_checkpoint_and_dispatch( + model, + model_file if not is_sharded else index_file, + device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + force_hooks=force_hook, + strict=True, + ) + model._undo_temp_convert_self_to_deprecated_attention_blocks() + else: + raise e + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + model = cls.from_config(config, **unused_kwargs) + + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict: OrderedDict, + resolved_archive_file, + pretrained_model_name_or_path: Union[str, os.PathLike], + ignore_mismatched_sizes: bool = False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + return expected_modules, optional_parameters + + # Adapted from `transformers` modeling_utils.py + def _get_no_split_modules(self, device_map: str): + """ + Get the modules of the model that should not be spit when using device_map. We iterate through the modules to + get the underlying `_no_split_modules`. + + Args: + device_map (`str`): + The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"] + + Returns: + `List[str]`: List of modules that should not be split + """ + _no_split_modules = set() + modules_to_check = [self] + while len(modules_to_check) > 0: + module = modules_to_check.pop(-1) + # if the module does not appear in _no_split_modules, we also check the children + if module.__class__.__name__ not in _no_split_modules: + if isinstance(module, ModelMixin): + if module._no_split_modules is None: + raise ValueError( + f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model " + "class needs to implement the `_no_split_modules` attribute." + ) + else: + _no_split_modules = _no_split_modules | set(module._no_split_modules) + modules_to_check += list(module.children()) + return list(_no_split_modules) + + @property + def device(self) -> torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (trainable or non-embedding) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters. + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embedding parameters. + + Returns: + `int`: The number of parameters. + + Example: + + ```py + from diffusers import UNet2DConditionModel + + model_id = "runwayml/stable-diffusion-v1-5" + unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet") + unet.num_parameters(only_trainable=True) + 859520964 + ``` + """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) + else: + return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: + deprecated_attention_block_paths = [] + + def recursive_find_attn_block(name, module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_paths.append(name) + + for sub_name, sub_module in module.named_children(): + sub_name = sub_name if name == "" else f"{name}.{sub_name}" + recursive_find_attn_block(sub_name, sub_module) + + recursive_find_attn_block("", self) + + # NOTE: we have to check if the deprecated parameters are in the state dict + # because it is possible we are loading from a state dict that was already + # converted + + for path in deprecated_attention_block_paths: + # group_norm path stays the same + + # query -> to_q + if f"{path}.query.weight" in state_dict: + state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") + if f"{path}.query.bias" in state_dict: + state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") + + # key -> to_k + if f"{path}.key.weight" in state_dict: + state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") + if f"{path}.key.bias" in state_dict: + state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") + + # value -> to_v + if f"{path}.value.weight" in state_dict: + state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") + if f"{path}.value.bias" in state_dict: + state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") + + # proj_attn -> to_out.0 + if f"{path}.proj_attn.weight" in state_dict: + state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") + if f"{path}.proj_attn.bias" in state_dict: + state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") + + def _temp_convert_self_to_deprecated_attention_blocks(self) -> None: + deprecated_attention_block_modules = [] + + def recursive_find_attn_block(module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_modules.append(module) + + for sub_module in module.children(): + recursive_find_attn_block(sub_module) + + recursive_find_attn_block(self) + + for module in deprecated_attention_block_modules: + module.query = module.to_q + module.key = module.to_k + module.value = module.to_v + module.proj_attn = module.to_out[0] + + # We don't _have_ to delete the old attributes, but it's helpful to ensure + # that _all_ the weights are loaded into the new attributes and we're not + # making an incorrect assumption that this model should be converted when + # it really shouldn't be. + del module.to_q + del module.to_k + del module.to_v + del module.to_out + + def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None: + deprecated_attention_block_modules = [] + + def recursive_find_attn_block(module) -> None: + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_modules.append(module) + + for sub_module in module.children(): + recursive_find_attn_block(sub_module) + + recursive_find_attn_block(self) + + for module in deprecated_attention_block_modules: + module.to_q = module.query + module.to_k = module.key + module.to_v = module.value + module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)]) + + del module.query + del module.key + del module.value + del module.proj_attn + + +class LegacyModelMixin(ModelMixin): + r""" + A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more + pipeline-specific classes (like `DiTTransformer2DModel`). + """ + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + # To prevent dependency import problem. + from .model_loading_utils import _fetch_remapped_cls_from_config + + # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls. + kwargs_copy = kwargs.copy() + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, _, _ = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + **kwargs, + ) + # resolve remapping + remapped_class = _fetch_remapped_cls_from_config(config, cls) + + return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy) diff --git a/diffusers/src/diffusers/models/normalization.py b/diffusers/src/diffusers/models/normalization.py new file mode 100755 index 0000000..84c2772 --- /dev/null +++ b/diffusers/src/diffusers/models/normalization.py @@ -0,0 +1,457 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import is_torch_version +from .activations import get_activation +from .embeddings import ( + CombinedTimestepLabelEmbeddings, + PixArtAlphaCombinedTimestepSizeEmbeddings, +) + + +class AdaLayerNorm(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`, *optional*): The size of the embeddings dictionary. + output_dim (`int`, *optional*): + norm_elementwise_affine (`bool`, defaults to `False): + norm_eps (`bool`, defaults to `False`): + chunk_dim (`int`, defaults to `0`): + """ + + def __init__( + self, + embedding_dim: int, + num_embeddings: Optional[int] = None, + output_dim: Optional[int] = None, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): + super().__init__() + + self.chunk_dim = chunk_dim + output_dim = output_dim or embedding_dim * 2 + + if num_embeddings is not None: + self.emb = nn.Embedding(num_embeddings, embedding_dim) + else: + self.emb = None + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.emb is not None: + temb = self.emb(timestep) + + temb = self.linear(self.silu(temb)) + + if self.chunk_dim == 1: + # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the + # other if-branch. This branch is specific to CogVideoX for now. + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + else: + scale, shift = temb.chunk(2, dim=0) + + x = self.norm(x) * (1 + scale) + shift + return x + + +class FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm( + inputs.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +class AdaLayerNormZero(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True): + super().__init__() + if num_embeddings is not None: + self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + else: + self.emb = None + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + elif norm_type == "fp32_layer_norm": + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def forward( + self, + x: torch.Tensor, + timestep: Optional[torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + hidden_dtype: Optional[torch.dtype] = None, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if self.emb is not None: + emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLayerNormZeroSingle(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa + + +class LuminaRMSNormZero(nn.Module): + """ + Norm layer adaptive RMS normalization zero. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # emb = self.emb(timestep, encoder_hidden_states, encoder_mask) + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + + return x, gate_msa, scale_mlp, gate_mlp + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class AdaGroupNorm(nn.Module): + r""" + GroupNorm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + num_groups (`int`): The number of groups to separate the channels into. + act_fn (`str`, *optional*, defaults to `None`): The activation function to use. + eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability. + """ + + def __init__( + self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 + ): + super().__init__() + self.num_groups = num_groups + self.eps = eps + + if act_fn is None: + self.act = None + else: + self.act = get_activation(act_fn) + + self.linear = nn.Linear(embedding_dim, out_dim * 2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + if self.act: + emb = self.act(emb) + emb = self.linear(emb) + emb = emb[:, :, None, None] + scale, shift = emb.chunk(2, dim=1) + + x = F.group_norm(x, self.num_groups, eps=self.eps) + x = x * (1 + scale) + shift + return x + + +class AdaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class LuminaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + out_dim: Optional[int] = None, + ): + super().__init__() + # AdaLN + self.silu = nn.SiLU() + self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + else: + raise ValueError(f"unknown norm_type {norm_type}") + # linear_2 + if out_dim is not None: + self.linear_2 = nn.Linear( + embedding_dim, + out_dim, + bias=bias, + ) + + def forward( + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, + ) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + scale = emb + x = self.norm(x) * (1 + scale)[:, None, :] + + if self.linear_2 is not None: + x = self.linear_2(x) + + return x + + +class CogVideoXLayerNormZero(nn.Module): + def __init__( + self, + conditioning_dim: int, + embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) + self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward( + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] + return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] + + def forward_wo_text( + self, hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + return hidden_states, gate[:, None, :] + +if is_torch_version(">=", "2.1.0"): + LayerNorm = nn.LayerNorm +else: + # Has optional bias parameter compared to torch layer norm + # TODO: replace with torch layernorm once min required torch version >= 2.1 + class LayerNorm(nn.Module): + def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + self.bias = nn.Parameter(torch.zeros(dim)) if bias else None + else: + self.weight = None + self.bias = None + + def forward(self, input): + return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + +class GlobalResponseNorm(nn.Module): + # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * nx) + self.beta + x diff --git a/diffusers/src/diffusers/models/resnet.py b/diffusers/src/diffusers/models/resnet.py new file mode 100755 index 0000000..00b55cd --- /dev/null +++ b/diffusers/src/diffusers/models/resnet.py @@ -0,0 +1,797 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# `TemporalConvLayer` Copyright 2024 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import deprecate +from .activations import get_activation +from .attention_processor import SpatialNorm +from .downsampling import ( # noqa + Downsample1D, + Downsample2D, + FirDownsample2D, + KDownsample2D, + downsample_2d, +) +from .normalization import AdaGroupNorm +from .upsampling import ( # noqa + FirUpsample2D, + KUpsample2D, + Upsample1D, + Upsample2D, + upfirdn2d_native, + upsample_2d, +) + + +class ResnetBlockCondNorm2D(nn.Module): + r""" + A Resnet block that use normalization layer that incorporate conditioning information. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + time_embedding_norm (`str`, *optional*, default to `"ada_group"` ): + The normalization layer for time embedding `temb`. Currently only support "ada_group" or "spatial". + kernel (`torch.Tensor`, optional, default to None): FIR filter, see + [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. + up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. + """ + + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + eps: float = 1e-6, + non_linearity: str = "swish", + time_embedding_norm: str = "ada_group", # ada_group, spatial + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + + if groups_out is None: + groups_out = groups + + if self.time_embedding_norm == "ada_group": # ada_group + self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) + elif self.time_embedding_norm == "spatial": + self.norm1 = SpatialNorm(in_channels, temb_channels) + else: + raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}") + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if self.time_embedding_norm == "ada_group": # ada_group + self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) + elif self.time_embedding_norm == "spatial": # spatial + self.norm2 = SpatialNorm(out_channels, temb_channels) + else: + raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}") + + self.dropout = torch.nn.Dropout(dropout) + + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + + self.upsample = self.downsample = None + if self.up: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") + + self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states, temb) + + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states, temb) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class ResnetBlock2D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a + stronger conditioning with scale and shift. + kernel (`torch.Tensor`, optional, default to None): FIR filter, see + [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. + up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. + """ + + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + time_embedding_norm: str = "default", # default, scale_shift, + kernel: Optional[torch.Tensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + ): + super().__init__() + if time_embedding_norm == "ada_group": + raise ValueError( + "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead", + ) + if time_embedding_norm == "spatial": + raise ValueError( + "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead", + ) + + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + self.skip_time_act = skip_time_act + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = nn.Linear(temb_channels, out_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels) + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") + + self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + + if self.time_embedding_norm == "default": + if temb is not None: + hidden_states = hidden_states + temb + hidden_states = self.norm2(hidden_states) + elif self.time_embedding_norm == "scale_shift": + if temb is None: + raise ValueError( + f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}" + ) + time_scale, time_shift = torch.chunk(temb, 2, dim=1) + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + time_scale) + time_shift + else: + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +# unet_rl.py +def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor: + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + elif len(tensor.shape) == 4: + return tensor[:, :, 0, :] + else: + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + + +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + + Parameters: + inp_channels (`int`): Number of input channels. + out_channels (`int`): Number of output channels. + kernel_size (`int` or `tuple`): Size of the convolving kernel. + n_groups (`int`, default `8`): Number of groups to separate the channels into. + activation (`str`, defaults to `mish`): Name of the activation function. + """ + + def __init__( + self, + inp_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + n_groups: int = 8, + activation: str = "mish", + ): + super().__init__() + + self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.group_norm = nn.GroupNorm(n_groups, out_channels) + self.mish = get_activation(activation) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + intermediate_repr = self.conv1d(inputs) + intermediate_repr = rearrange_dims(intermediate_repr) + intermediate_repr = self.group_norm(intermediate_repr) + intermediate_repr = rearrange_dims(intermediate_repr) + output = self.mish(intermediate_repr) + return output + + +# unet_rl.py +class ResidualTemporalBlock1D(nn.Module): + """ + Residual 1D block with temporal convolutions. + + Parameters: + inp_channels (`int`): Number of input channels. + out_channels (`int`): Number of output channels. + embed_dim (`int`): Embedding dimension. + kernel_size (`int` or `tuple`): Size of the convolving kernel. + activation (`str`, defaults `mish`): It is possible to choose the right activation function. + """ + + def __init__( + self, + inp_channels: int, + out_channels: int, + embed_dim: int, + kernel_size: Union[int, Tuple[int, int]] = 5, + activation: str = "mish", + ): + super().__init__() + self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) + self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) + + self.time_emb_act = get_activation(activation) + self.time_emb = nn.Linear(embed_dim, out_channels) + + self.residual_conv = ( + nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() + ) + + def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + Args: + inputs : [ batch_size x inp_channels x horizon ] + t : [ batch_size x embed_dim ] + + returns: + out : [ batch_size x out_channels x horizon ] + """ + t = self.time_emb_act(t) + t = self.time_emb(t) + out = self.conv_in(inputs) + rearrange_dims(t) + out = self.conv_out(out) + return out + self.residual_conv(inputs) + + +class TemporalConvLayer(nn.Module): + """ + Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: + https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + + Parameters: + in_dim (`int`): Number of input channels. + out_dim (`int`): Number of output channels. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + """ + + def __init__( + self, + in_dim: int, + out_dim: Optional[int] = None, + dropout: float = 0.0, + norm_num_groups: int = 32, + ): + super().__init__() + out_dim = out_dim or in_dim + self.in_dim = in_dim + self.out_dim = out_dim + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(norm_num_groups, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor: + hidden_states = ( + hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) + ) + + identity = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.conv3(hidden_states) + hidden_states = self.conv4(hidden_states) + + hidden_states = identity + hidden_states + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape( + (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:] + ) + return hidden_states + + +class TemporalResnetBlock(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + temb_channels: int = 512, + eps: float = 1e-6, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + kernel_size = (3, 1, 1) + padding = [k // 2 for k in kernel_size] + + self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = nn.Conv3d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + ) + + if temb_channels is not None: + self.time_emb_proj = nn.Linear(temb_channels, out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(0.0) + self.conv2 = nn.Conv3d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + ) + + self.nonlinearity = get_activation("silu") + + self.use_in_shortcut = self.in_channels != out_channels + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, :, None, None] + temb = temb.permute(0, 2, 1, 3, 4) + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +# VideoResBlock +class SpatioTemporalResBlock(nn.Module): + r""" + A SpatioTemporal Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resenet. + temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet. + merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing. + merge_strategy (`str`, *optional*, defaults to `learned_with_images`): + The merge strategy to use for the temporal mixing. + switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`): + If `True`, switch the spatial and temporal mixing. + """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + temb_channels: int = 512, + eps: float = 1e-6, + temporal_eps: Optional[float] = None, + merge_factor: float = 0.5, + merge_strategy="learned_with_images", + switch_spatial_to_temporal_mix: bool = False, + ): + super().__init__() + + self.spatial_res_block = ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=eps, + ) + + self.temporal_res_block = TemporalResnetBlock( + in_channels=out_channels if out_channels is not None else in_channels, + out_channels=out_channels if out_channels is not None else in_channels, + temb_channels=temb_channels, + eps=temporal_eps if temporal_eps is not None else eps, + ) + + self.time_mixer = AlphaBlender( + alpha=merge_factor, + merge_strategy=merge_strategy, + switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix, + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ): + num_frames = image_only_indicator.shape[-1] + hidden_states = self.spatial_res_block(hidden_states, temb) + + batch_frames, channels, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states_mix = ( + hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + ) + hidden_states = ( + hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + ) + + if temb is not None: + temb = temb.reshape(batch_size, num_frames, -1) + + hidden_states = self.temporal_res_block(hidden_states, temb) + hidden_states = self.time_mixer( + x_spatial=hidden_states_mix, + x_temporal=hidden_states, + image_only_indicator=image_only_indicator, + ) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width) + return hidden_states + + +class AlphaBlender(nn.Module): + r""" + A module to blend spatial and temporal features. + + Parameters: + alpha (`float`): The initial value of the blending factor. + merge_strategy (`str`, *optional*, defaults to `learned_with_images`): + The merge strategy to use for the temporal mixing. + switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`): + If `True`, switch the spatial and temporal mixing. + """ + + strategies = ["learned", "fixed", "learned_with_images"] + + def __init__( + self, + alpha: float, + merge_strategy: str = "learned_with_images", + switch_spatial_to_temporal_mix: bool = False, + ): + super().__init__() + self.merge_strategy = merge_strategy + self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix # For TemporalVAE + + if merge_strategy not in self.strategies: + raise ValueError(f"merge_strategy needs to be in {self.strategies}") + + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"Unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor: + if self.merge_strategy == "fixed": + alpha = self.mix_factor + + elif self.merge_strategy == "learned": + alpha = torch.sigmoid(self.mix_factor) + + elif self.merge_strategy == "learned_with_images": + if image_only_indicator is None: + raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy") + + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + torch.sigmoid(self.mix_factor)[..., None], + ) + + # (batch, channel, frames, height, width) + if ndims == 5: + alpha = alpha[:, None, :, None, None] + # (batch*frames, height*width, channels) + elif ndims == 3: + alpha = alpha.reshape(-1)[:, None, None] + else: + raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5") + + else: + raise NotImplementedError + + return alpha + + def forward( + self, + x_spatial: torch.Tensor, + x_temporal: torch.Tensor, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + alpha = self.get_alpha(image_only_indicator, x_spatial.ndim) + alpha = alpha.to(x_spatial.dtype) + + if self.switch_spatial_to_temporal_mix: + alpha = 1.0 - alpha + + x = alpha * x_spatial + (1.0 - alpha) * x_temporal + return x diff --git a/diffusers/src/diffusers/models/resnet_flax.py b/diffusers/src/diffusers/models/resnet_flax.py new file mode 100755 index 0000000..f8bb478 --- /dev/null +++ b/diffusers/src/diffusers/models/resnet_flax.py @@ -0,0 +1,124 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flax.linen as nn +import jax +import jax.numpy as jnp + + +class FlaxUpsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + batch, height, width, channels = hidden_states.shape + hidden_states = jax.image.resize( + hidden_states, + shape=(batch, height * 2, width * 2, channels), + method="nearest", + ) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxDownsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(2, 2), + padding=((1, 1), (1, 1)), # padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim + # hidden_states = jnp.pad(hidden_states, pad_width=pad) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxResnetBlock2D(nn.Module): + in_channels: int + out_channels: int = None + dropout_prob: float = 0.0 + use_nin_shortcut: bool = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + out_channels = self.in_channels if self.out_channels is None else self.out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv1 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) + + self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.dropout = nn.Dropout(self.dropout_prob) + self.conv2 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut + + self.conv_shortcut = None + if use_nin_shortcut: + self.conv_shortcut = nn.Conv( + out_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, temb, deterministic=True): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.conv1(hidden_states) + + temb = self.time_emb_proj(nn.swish(temb)) + temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual diff --git a/diffusers/src/diffusers/models/transformers/__init__.py b/diffusers/src/diffusers/models/transformers/__init__.py new file mode 100755 index 0000000..99c3eaa --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/__init__.py @@ -0,0 +1,20 @@ +from ...utils import is_torch_available + + +if is_torch_available(): + from .auraflow_transformer_2d import AuraFlowTransformer2DModel + from .cogvideox_transformer_3d import CogVideoXTransformer3DModel + from .cogvideox_transformer_3d_inpainting import CogVideoXTransformer3DInpaintModel + from .dit_transformer_2d import DiTTransformer2DModel + from .dual_transformer_2d import DualTransformer2DModel + from .hunyuan_transformer_2d import HunyuanDiT2DModel + from .latte_transformer_3d import LatteTransformer3DModel + from .lumina_nextdit2d import LuminaNextDiT2DModel + from .pixart_transformer_2d import PixArtTransformer2DModel + from .prior_transformer import PriorTransformer + from .stable_audio_transformer import StableAudioDiTModel + from .t5_film_transformer import T5FilmDecoder + from .transformer_2d import Transformer2DModel + from .transformer_flux import FluxTransformer2DModel + from .transformer_sd3 import SD3Transformer2DModel + from .transformer_temporal import TransformerTemporalModel diff --git a/diffusers/src/diffusers/models/transformers/auraflow_transformer_2d.py b/diffusers/src/diffusers/models/transformers/auraflow_transformer_2d.py new file mode 100755 index 0000000..ad64df0 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -0,0 +1,544 @@ +# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention_processor import ( + Attention, + AttentionProcessor, + AuraFlowAttnProcessor2_0, + FusedAuraFlowAttnProcessor2_0, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormZero, FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Taken from the original aura flow inference code. +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +# Aura Flow patch embed doesn't use convs for projections. +# Additionally, it uses learned positional embeddings. +class AuraFlowPatchEmbed(nn.Module): + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + pos_embed_max_size=None, + ): + super().__init__() + + self.num_patches = (height // patch_size) * (width // patch_size) + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim) + self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1) + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + + def pe_selection_index_based_on_dim(self, h, w): + # select subset of positional embedding based on H, W, where H, W is size of latent + # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected + # because original input are in flattened format, we have to flatten this 2d grid as well. + h_p, w_p = h // self.patch_size, w // self.patch_size + original_pe_indexes = torch.arange(self.pos_embed.shape[1]) + h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5) + original_pe_indexes = original_pe_indexes.view(h_max, w_max) + starth = h_max // 2 - h_p // 2 + endh = starth + h_p + startw = w_max // 2 - w_p // 2 + endw = startw + w_p + original_pe_indexes = original_pe_indexes[starth:endh, startw:endw] + return original_pe_indexes.flatten() + + def forward(self, latent): + batch_size, num_channels, height, width = latent.size() + latent = latent.view( + batch_size, + num_channels, + height // self.patch_size, + self.patch_size, + width // self.patch_size, + self.patch_size, + ) + latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + latent = self.proj(latent) + pe_index = self.pe_selection_index_based_on_dim(height, width) + return latent + self.pos_embed[:, pe_index] + + +# Taken from the original Aura flow inference code. +# Our feedforward only has GELU but Aura uses SiLU. +class AuraFlowFeedForward(nn.Module): + def __init__(self, dim, hidden_dim=None) -> None: + super().__init__() + if hidden_dim is None: + hidden_dim = 4 * dim + + final_hidden_dim = int(2 * hidden_dim / 3) + final_hidden_dim = find_multiple(final_hidden_dim, 256) + + self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False) + self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False) + self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.linear_1(x)) * self.linear_2(x) + x = self.out_projection(x) + return x + + +class AuraFlowPreFinalBlock(nn.Module): + def __init__(self, embedding_dim: int, conditioning_embedding_dim: int): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = x * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +@maybe_allow_in_graph +class AuraFlowSingleTransformerBlock(nn.Module): + """Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT.""" + + def __init__(self, dim, num_attention_heads, attention_head_dim): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") + + processor = AuraFlowAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="fp32_layer_norm", + out_dim=dim, + bias=False, + out_bias=False, + processor=processor, + ) + + self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + self.ff = AuraFlowFeedForward(dim, dim * 4) + + def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor): + residual = hidden_states + + # Norm + Projection. + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + # Attention. + attn_output = self.attn(hidden_states=norm_hidden_states) + + # Process attention outputs for the `hidden_states`. + hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) + hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(hidden_states) + hidden_states = gate_mlp.unsqueeze(1) * ff_output + hidden_states = residual + hidden_states + + return hidden_states + + +@maybe_allow_in_graph +class AuraFlowJointTransformerBlock(nn.Module): + r""" + Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive): + + * QK Norm in the attention blocks + * No bias in the attention blocks + * Most LayerNorms are in FP32 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + is_last (`bool`): Boolean to determine if this is the last block in the model. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") + self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") + + processor = AuraFlowAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + added_proj_bias=False, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="fp32_layer_norm", + out_dim=dim, + bias=False, + out_bias=False, + processor=processor, + context_pre_only=False, + ) + + self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + self.ff = AuraFlowFeedForward(dim, dim * 4) + self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + self.ff_context = AuraFlowFeedForward(dim, dim * 4) + + def forward( + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor + ): + residual = hidden_states + residual_context = encoder_hidden_states + + # Norm + Projection. + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + + # Process attention outputs for the `hidden_states`. + hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) + hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states) + hidden_states = residual + hidden_states + + # Process attention outputs for the `encoder_hidden_states`. + encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output) + encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states) + encoder_hidden_states = residual_context + encoder_hidden_states + + return encoder_hidden_states, hidden_states + + +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). + + Parameters: + sample_size (`int`): The width of the latent images. This is fixed during training since + it is used to learn a number of position embeddings. + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use. + num_single_dit_layers (`int`, *optional*, defaults to 4): + The number of layers of Transformer blocks to use. These blocks use concatenated image and text + representations. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. + out_channels (`int`, defaults to 16): Number of output channels. + pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents. + """ + + _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"] + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 64, + patch_size: int = 2, + in_channels: int = 4, + num_mmdit_layers: int = 4, + num_single_dit_layers: int = 32, + attention_head_dim: int = 256, + num_attention_heads: int = 12, + joint_attention_dim: int = 2048, + caption_projection_dim: int = 3072, + out_channels: int = 4, + pos_embed_max_size: int = 1024, + ): + super().__init__() + default_out_channels = in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = AuraFlowPatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + pos_embed_max_size=pos_embed_max_size, + ) + + self.context_embedder = nn.Linear( + self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False + ) + self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True) + self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) + + self.joint_transformer_blocks = nn.ModuleList( + [ + AuraFlowJointTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_mmdit_layers) + ] + ) + self.single_transformer_blocks = nn.ModuleList( + [ + AuraFlowSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for _ in range(self.config.num_single_dit_layers) + ] + ) + + self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + # https://arxiv.org/abs/2309.16588 + # prevents artifacts in the attention maps + self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAuraFlowAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + height, width = hidden_states.shape[-2:] + + # Apply patch embedding, timestep embedding, and project the caption embeddings. + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype) + temb = self.time_step_proj(temb) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + encoder_hidden_states = torch.cat( + [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 + ) + + # MMDiT blocks. + for index_block, block in enumerate(self.joint_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + + # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text) + if len(self.single_transformer_blocks) > 0: + encoder_seq_len = encoder_hidden_states.size(1) + combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + combined_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + combined_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb) + + hidden_states = combined_hidden_states[:, encoder_seq_len:] + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.config.patch_size + out_channels = self.config.out_channels + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d.py new file mode 100755 index 0000000..acf029c --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -0,0 +1,646 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn +import numpy as np + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import Attention, FeedForward +from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, CogVideoXAttnProcessor2_0_wo_text, CogVideoXAttnProcessor2_0_resample, FusedCogVideoXAttnProcessor2_0 +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class CogVideoXBlock(nn.Module): + r""" + Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + time_embed_dim (`int`): + The number of channels in timestep embedding. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + qk_norm (`bool`, defaults to `True`): + Whether or not to use normalization after query and key projections in Attention. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*, defaults to `None`): + Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in Feed-forward layer. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in Attention output projection layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float = 0.0, + activation_fn: str = "gelu-approximate", + attention_bias: bool = False, + qk_norm: bool = True, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = True, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + wo_text: bool = False, + id_pool_resample_learnable: bool = False, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + if wo_text: + self.processor = CogVideoXAttnProcessor2_0_wo_text() + elif id_pool_resample_learnable: + self.processor = CogVideoXAttnProcessor2_0_resample() + else: + self.processor = CogVideoXAttnProcessor2_0() + self.attn1 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=attention_bias, + out_bias=attention_out_bias, + processor=self.processor, + ) + + # 2. Feed Forward + self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + resample_mask: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) + if attention_kwargs and "prev_hidden_states" in attention_kwargs: + prev_encoder_hidden_states, prev_hidden_states = attention_kwargs["prev_hidden_states"][:, :text_seq_length], attention_kwargs["prev_hidden_states"][:, text_seq_length:] + norm_prev_hidden_states, norm_prev_encoder_hidden_states, prev_gate_msa, prev_enc_gate_msa = self.norm1( + prev_hidden_states, prev_encoder_hidden_states, temb + ) + attention_kwargs["prev_hidden_states"] = torch.cat([norm_prev_encoder_hidden_states, norm_prev_hidden_states], dim=1) + + # attention + if isinstance(self.processor, CogVideoXAttnProcessor2_0_resample): + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + resample_mask=resample_mask, + **attention_kwargs if attention_kwargs else {}, + ) + elif isinstance(self.processor, CogVideoXAttnProcessor2_0): + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + **attention_kwargs if attention_kwargs else {}, + ) + else: + raise ValueError(f"Unsupported processor type: {type(self.processor)}") + + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + def forward_wo_text( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + + # norm & modulate + norm_hidden_states, gate_msa = self.norm1.forward_wo_text( + hidden_states, temb + ) + + # attention + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + gate_msa * attn_hidden_states + + # norm & modulate + norm_hidden_states, gate_ff = self.norm2.forward_wo_text( + hidden_states, temb + ) + + # feed-forward + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_ff * ff_output + + return hidden_states + +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + """ + A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). + + Parameters: + num_attention_heads (`int`, defaults to `30`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `16`): + The number of channels in the output. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_width (`int`, defaults to `90`): + The width of the input latents. + sample_height (`int`, defaults to `60`): + The height of the input latents. + sample_frames (`int`, defaults to `49`): + The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 + instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, + but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with + K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + temporal_compression_ratio (`int`, defaults to `4`): + The compression ratio across the temporal dimension. See documentation for `sample_frames`. + max_text_seq_length (`int`, defaults to `226`): + The maximum sequence length of the input text embeddings. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + timestep_activation_fn (`str`, defaults to `"silu"`): + Activation function to use when generating the timestep embeddings. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-5`): + The epsilon value to use in normalization layers. + spatial_interpolation_scale (`float`, defaults to `1.875`): + Scaling factor to apply in 3D positional embeddings across spatial dimensions. + temporal_interpolation_scale (`float`, defaults to `1.0`): + Scaling factor to apply in 3D positional embeddings across temporal dimensions. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: int = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, + use_learned_positional_embeddings: bool = False, + id_pool_resample_learnable: Optional[bool] = False, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + if not use_rotary_positional_embeddings and use_learned_positional_embeddings: + raise ValueError( + "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional " + "embeddings. If you're using a custom model and/or believe this should be supported, please open an " + "issue at https://github.com/huggingface/diffusers/issues." + ) + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + text_embed_dim=text_embed_dim, + bias=True, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + ) + self.embedding_dropout = nn.Dropout(dropout) + + # 2. Time embeddings + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + + # 3. Define spatio-temporal transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + id_pool_resample_learnable=id_pool_resample_learnable, + ) + for _ in range(num_layers) + ] + ) + self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + + # 4. Output blocks + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + chunk_dim=1, + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + branch_block_samples: Optional[torch.Tensor] = None, + branch_block_masks: Optional[torch.Tensor] = None, + add_first: Optional[bool] = False, + self_guidance_hidden_states: Optional[torch.Tensor] = None, + self_guidance_masks: Optional[torch.Tensor] = None, + return_hidden_states: Optional[bool] = False, + return_resample_mask: Optional[bool] = False, + id_pool_resample_learnable: Optional[bool] = False, + return_dict: bool = True, + ): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + if self_guidance_masks is not None: + hidden_states, masks = self.patch_embed(encoder_hidden_states, hidden_states, masks=self_guidance_masks) + # hidden_states: (B, Len_t + Len_v, C); masks: (B, Len_v, C) + masks = masks.repeat(1, 1, int(hidden_states.shape[-1] / masks.shape[-1])) + elif branch_block_masks is not None: + hidden_states, masks = self.patch_embed(encoder_hidden_states, hidden_states, masks=branch_block_masks) + masks = masks.repeat(1, 1, int(hidden_states.shape[-1] / masks.shape[-1])) + else: + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # prepare the attention mask + if id_pool_resample_learnable or return_resample_mask: + + if masks is None: + raise ValueError("id_pool_resample needs masks") + image_patch_len = hidden_states.shape[1] // num_frames + total_length = encoder_hidden_states.shape[1] + hidden_states.shape[1] + resample_mask = torch.zeros((batch_size, total_length), + device=hidden_states.device, + dtype=torch.bool) + resample_mask[:, :text_seq_length] = False + resample_mask[:, text_seq_length:] = masks[:, :, 0].bool() + attention_mask = None + else: + attention_mask = None + resample_mask = None + + # 3. Transformer blocks + hidden_states_list = [] + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + attention_mask, + resample_mask, + attention_kwargs, + **ckpt_kwargs, + ) + else: + current_block_kwargs = attention_kwargs.copy() if attention_kwargs else {} + if attention_kwargs and "prev_hidden_states" in attention_kwargs: + layer_states = attention_kwargs["prev_hidden_states"].get(i) + if layer_states is not None: + current_block_kwargs["prev_hidden_states"] = layer_states + current_block_kwargs["prev_clip_weight"] = attention_kwargs["prev_clip_weight"] + prev_resample_mask = attention_kwargs.get("prev_resample_mask") + if prev_resample_mask is not None: + current_block_kwargs["prev_resample_mask"] = prev_resample_mask + + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + resample_mask=resample_mask, + attention_kwargs=current_block_kwargs, + ) + if self_guidance_hidden_states is not None: + hidden_states = torch.where(masks == False, self_guidance_hidden_states[i], hidden_states) + + if branch_block_samples is not None: + if not add_first: + interval_control = len(self.transformer_blocks) / len(branch_block_samples) + interval_control = int(np.ceil(interval_control)) + if branch_block_masks is None: + hidden_states = hidden_states + branch_block_samples[i // interval_control] + else: + hidden_states = torch.where(masks == False, hidden_states + branch_block_samples[i // interval_control], hidden_states) + else: + if i < len(branch_block_samples): + if branch_block_masks is None: + hidden_states = hidden_states + branch_block_samples[i] + else: + hidden_states = torch.where(masks == False, hidden_states + branch_block_samples[i], hidden_states) + + if return_hidden_states: + hidden_states_list.append(torch.cat([encoder_hidden_states, hidden_states], dim=1)) + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + if return_hidden_states: + if return_resample_mask: + return (output, hidden_states_list, resample_mask) + else: + return (output, hidden_states_list) + else: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d_inpainting.py b/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d_inpainting.py new file mode 100755 index 0000000..a1d77d6 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d_inpainting.py @@ -0,0 +1,577 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn +import numpy as np + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import Attention, FeedForward +from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class CogVideoXBlock(nn.Module): + r""" + Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + time_embed_dim (`int`): + The number of channels in timestep embedding. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + qk_norm (`bool`, defaults to `True`): + Whether or not to use normalization after query and key projections in Attention. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*, defaults to `None`): + Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in Feed-forward layer. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in Attention output projection layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float = 0.0, + activation_fn: str = "gelu-approximate", + attention_bias: bool = False, + qk_norm: bool = True, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = True, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.attn1 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=attention_bias, + out_bias=attention_out_bias, + processor=CogVideoXAttnProcessor2_0(), + ) + + # 2. Feed Forward + self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) + + # attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3DInpaintModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + """ + A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). + + Parameters: + num_attention_heads (`int`, defaults to `30`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `16`): + The number of channels in the output. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_width (`int`, defaults to `90`): + The width of the input latents. + sample_height (`int`, defaults to `60`): + The height of the input latents. + sample_frames (`int`, defaults to `49`): + The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 + instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, + but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with + K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + temporal_compression_ratio (`int`, defaults to `4`): + The compression ratio across the temporal dimension. See documentation for `sample_frames`. + max_text_seq_length (`int`, defaults to `226`): + The maximum sequence length of the input text embeddings. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + timestep_activation_fn (`str`, defaults to `"silu"`): + Activation function to use when generating the timestep embeddings. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-5`): + The epsilon value to use in normalization layers. + spatial_interpolation_scale (`float`, defaults to `1.875`): + Scaling factor to apply in 3D positional embeddings across spatial dimensions. + temporal_interpolation_scale (`float`, defaults to `1.0`): + Scaling factor to apply in 3D positional embeddings across temporal dimensions. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: int = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, + use_learned_positional_embeddings: bool = False, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + if not use_rotary_positional_embeddings and use_learned_positional_embeddings: + raise ValueError( + "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional " + "embeddings. If you're using a custom model and/or believe this should be supported, please open an " + "issue at https://github.com/huggingface/diffusers/issues." + ) + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + in_channels=in_channels * 2 + 1, + embed_dim=inner_dim, + text_embed_dim=text_embed_dim, + bias=True, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + ) + self.embedding_dropout = nn.Dropout(dropout) + + # 2. Time embeddings + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + + # 3. Define spatio-temporal transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + + # 4. Output blocks + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + chunk_dim=1, + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + @classmethod + def from_transformer( + cls, + transformer, + num_layers: int = 42, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + load_weights_from_transformer=True, + ): + config = transformer.config + config["num_layers"] = num_layers + config["attention_head_dim"] = attention_head_dim + config["num_attention_heads"] = num_attention_heads + + branch = cls(**config) + + if load_weights_from_transformer: + conv_in_condition_weight=torch.zeros_like(branch.patch_embed.proj.weight) + conv_in_condition_weight[:,:config.in_channels,...]=transformer.patch_embed.proj.weight + conv_in_condition_weight[:,config.in_channels:2*config.in_channels,...]=transformer.patch_embed.proj.weight + branch.patch_embed.proj.weight.data=torch.nn.Parameter(conv_in_condition_weight) + branch.patch_embed.proj.bias.data=transformer.patch_embed.proj.bias + + branch.embedding_dropout.load_state_dict(transformer.embedding_dropout.state_dict()) + branch.time_proj.load_state_dict(transformer.time_proj.state_dict()) + branch.time_embedding.load_state_dict(transformer.time_embedding.state_dict()) + + branch.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) + + branch.norm_final.load_state_dict(transformer.norm_final.state_dict()) + branch.norm_out.load_state_dict(transformer.norm_out.state_dict()) + branch.proj_out.load_state_dict(transformer.proj_out.state_dict()) + + return branch + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + branch_block_samples: Optional[torch.Tensor] = None, + add_first: Optional[bool] = False, + self_guidance_hidden_states: Optional[torch.Tensor] = None, + self_guidance_masks: Optional[torch.Tensor] = None, + return_hidden_states: Optional[bool] = False, + return_dict: bool = True, + ): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + if self_guidance_masks is not None: + hidden_states, masks = self.patch_embed(encoder_hidden_states, hidden_states, masks=self_guidance_masks) + # hidden_states: (B, Len_t + Len_v, C); masks: (B, Len_v, C) + masks = masks.repeat(1, 1, int(hidden_states.shape[-1] / masks.shape[-1])) + else: + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # 3. Transformer blocks + hidden_states_list = [] + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + if self_guidance_hidden_states is not None: + # 1. adding + # hidden_states += self_guidance_hidden_states[i] + # 2. swap based on mask + hidden_states = torch.where(masks == 1, self_guidance_hidden_states[i], hidden_states) + + if return_hidden_states: + hidden_states_list.append(hidden_states) + + if branch_block_samples is not None: + if not add_first: + interval_control = len(self.transformer_blocks) / len(branch_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + branch_block_samples[i // interval_control] + # print(f"adding {i // interval_control + 1} - {len(branch_block_samples)} for {i + 1} - {len(self.transformer_blocks)}") + else: + # print(f"adding first {i} - {len(branch_block_samples)}") + if i < len(branch_block_samples): + + hidden_states = hidden_states + branch_block_samples[i] + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + if return_hidden_states: + return (output, hidden_states_list) + else: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/diffusers/src/diffusers/models/transformers/dit_transformer_2d.py b/diffusers/src/diffusers/models/transformers/dit_transformer_2d.py new file mode 100755 index 0000000..9f89577 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/dit_transformer_2d.py @@ -0,0 +1,240 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ..attention import BasicTransformerBlock +from ..embeddings import PatchEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DiTTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748). + + Parameters: + num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (int, optional, defaults to 72): The number of channels in each head. + in_channels (int, defaults to 4): The number of channels in the input. + out_channels (int, optional): + The number of channels in the output. Specify this parameter if the output channel number differs from the + input. + num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use. + dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks. + norm_num_groups (int, optional, defaults to 32): + Number of groups for group normalization within Transformer blocks. + attention_bias (bool, optional, defaults to True): + Configure if the Transformer blocks' attention should contain a bias parameter. + sample_size (int, defaults to 32): + The width of the latent images. This parameter is fixed during training. + patch_size (int, defaults to 2): + Size of the patches the model processes, relevant for architectures working on non-sequential data. + activation_fn (str, optional, defaults to "gelu-approximate"): + Activation function to use in feed-forward networks within Transformer blocks. + num_embeds_ada_norm (int, optional, defaults to 1000): + Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during + inference. + upcast_attention (bool, optional, defaults to False): + If true, upcasts the attention mechanism dimensions for potentially improved performance. + norm_type (str, optional, defaults to "ada_norm_zero"): + Specifies the type of normalization used, can be 'ada_norm_zero'. + norm_elementwise_affine (bool, optional, defaults to False): + If true, enables element-wise affine parameters in the normalization layers. + norm_eps (float, optional, defaults to 1e-5): + A small constant added to the denominator in normalization layers to prevent division by zero. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 72, + in_channels: int = 4, + out_channels: Optional[int] = None, + num_layers: int = 28, + dropout: float = 0.0, + norm_num_groups: int = 32, + attention_bias: bool = True, + sample_size: int = 32, + patch_size: int = 2, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: Optional[int] = 1000, + upcast_attention: bool = False, + norm_type: str = "ada_norm_zero", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + ): + super().__init__() + + # Validate inputs. + if norm_type != "ada_norm_zero": + raise NotImplementedError( + f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." + ) + elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None: + raise ValueError( + f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." + ) + + # Set some common variables used across the board. + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.out_channels = in_channels if out_channels is None else out_channels + self.gradient_checkpointing = False + + # 2. Initialize the position embedding and transformer blocks. + self.height = self.config.sample_size + self.width = self.config.sample_size + + self.patch_size = self.config.patch_size + self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + ) + for _ in range(self.config.num_layers) + ] + ) + + # 3. Output blocks. + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Linear( + self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + return_dict: bool = True, + ): + """ + The [`DiTTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. Input + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + None, + None, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/diffusers/src/diffusers/models/transformers/dual_transformer_2d.py b/diffusers/src/diffusers/models/transformers/dual_transformer_2d.py new file mode 100755 index 0000000..1c48c4e --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/dual_transformer_2d.py @@ -0,0 +1,156 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from torch import nn + +from ..modeling_outputs import Transformer2DModelOutput +from .transformer_2d import Transformer2DModel + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def forward( + self, + hidden_states, + encoder_hidden_states, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`torch.Tensor`, *optional*): + Optional attention mask to be applied in Attention. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformers.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformers.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + # attention_mask is not used yet + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) diff --git a/diffusers/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/diffusers/src/diffusers/models/transformers/hunyuan_transformer_2d.py new file mode 100755 index 0000000..7f3dab2 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -0,0 +1,578 @@ +# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0 +from ..embeddings import ( + HunyuanCombinedTimestepTextSizeStyleEmbedding, + PatchEmbed, + PixArtAlphaTextProjection, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class AdaLayerNormShift(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim) + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype)) + x = self.norm(x) + shift.unsqueeze(dim=1) + return x + + +@maybe_allow_in_graph +class HunyuanDiTBlock(nn.Module): + r""" + Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and + QKNorm + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of headsto use for multi-head attention. + cross_attention_dim (`int`,*optional*): + The size of the encoder_hidden_states vector for cross attention. + dropout(`float`, *optional*, defaults to 0.0): + The dropout probability to use. + activation_fn (`str`,*optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. . + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, *optional*, defaults to 1e-6): + A small constant added to the denominator in normalization layers to prevent division by zero. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*): + The size of the hidden layer in the feed-forward block. Defaults to `None`. + ff_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the feed-forward block. + skip (`bool`, *optional*, defaults to `False`): + Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks. + qk_norm (`bool`, *optional*, defaults to `True`): + Whether to use normalization in QK calculation. Defaults to `True`. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + cross_attention_dim: int = 1024, + dropout=0.0, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-6, + final_dropout: bool = False, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + skip: bool = False, + qk_norm: bool = True, + ): + super().__init__() + + # Define 3 blocks. Each block has its own normalization layer. + # NOTE: when new version comes, check norm2 and norm 3 + # 1. Self-Attn + self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=HunyuanAttnProcessor2_0(), + ) + + # 2. Cross-Attn + self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=HunyuanAttnProcessor2_0(), + ) + # 3. Feed-forward + self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, ### 0.0 + activation_fn=activation_fn, ### approx GeLU + final_dropout=final_dropout, ### 0.0 + inner_dim=ff_inner_dim, ### int(dim * mlp_ratio) + bias=ff_bias, + ) + + # 4. Skip Connection + if skip: + self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True) + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb=None, + skip=None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat([hidden_states, skip], dim=-1) + cat = self.skip_norm(cat) + hidden_states = self.skip_linear(cat) + + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct + attn_output = self.attn1( + norm_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + attn_output + + # 2. Cross-Attention + hidden_states = hidden_states + self.attn2( + self.norm2(hidden_states), + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + # FFN Layer ### TODO: switch norm2 and norm3 in the state dict + mlp_inputs = self.norm3(hidden_states) + hidden_states = hidden_states + self.ff(mlp_inputs) + + return hidden_states + + +class HunyuanDiT2DModel(ModelMixin, ConfigMixin): + """ + HunYuanDiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): + The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + patch_size (`int`, *optional*): + The size of the patch to use for the input. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. + sample_size (`int`, *optional*): + The width of the latent images. This is fixed during training since it is used to learn a number of + position embeddings. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + cross_attention_dim (`int`, *optional*): + The number of dimension in the clip text embedding. + hidden_size (`int`, *optional*): + The size of hidden layer in the conditioning embedding layers. + num_layers (`int`, *optional*, defaults to 1): + The number of layers of Transformer blocks to use. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the hidden layer size to the input size. + learn_sigma (`bool`, *optional*, defaults to `True`): + Whether to predict variance. + cross_attention_dim_t5 (`int`, *optional*): + The number dimensions in t5 text embedding. + pooled_projection_dim (`int`, *optional*): + The size of the pooled projection. + text_len (`int`, *optional*): + The length of the clip text embedding. + text_len_t5 (`int`, *optional*): + The length of the T5 text embedding. + use_style_cond_and_image_meta_size (`bool`, *optional*): + Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2 + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "gelu-approximate", + sample_size=32, + hidden_size=1152, + num_layers: int = 28, + mlp_ratio: float = 4.0, + learn_sigma: bool = True, + cross_attention_dim: int = 1024, + norm_type: str = "layer_norm", + cross_attention_dim_t5: int = 2048, + pooled_projection_dim: int = 1024, + text_len: int = 77, + text_len_t5: int = 256, + use_style_cond_and_image_meta_size: bool = True, + ): + super().__init__() + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.num_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + + self.text_embedder = PixArtAlphaTextProjection( + in_features=cross_attention_dim_t5, + hidden_size=cross_attention_dim_t5 * 4, + out_features=cross_attention_dim, + act_fn="silu_fp32", + ) + + self.text_embedding_padding = nn.Parameter( + torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32) + ) + + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + in_channels=in_channels, + embed_dim=hidden_size, + patch_size=patch_size, + pos_embed_type=None, + ) + + self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding( + hidden_size, + pooled_projection_dim=pooled_projection_dim, + seq_len=text_len_t5, + cross_attention_dim=cross_attention_dim_t5, + use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size, + ) + + # HunyuanDiT Blocks + self.blocks = nn.ModuleList( + [ + HunyuanDiTBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + activation_fn=activation_fn, + ff_inner_dim=int(self.inner_dim * mlp_ratio), + cross_attention_dim=cross_attention_dim, + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + skip=layer > num_layers // 2, + ) + for layer in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedHunyuanAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(HunyuanAttnProcessor2_0()) + + def forward( + self, + hidden_states, + timestep, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + image_rotary_emb=None, + controlnet_block_samples=None, + return_dict=True, + ): + """ + The [`HunyuanDiT2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`): + The input tensor. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of `BertModel`. + text_embedding_mask: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of `BertModel`. + encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder. + text_embedding_mask_t5: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of T5 Text Encoder. + image_meta_size (torch.Tensor): + Conditional embedding indicate the image sizes + style: torch.Tensor: + Conditional embedding indicate the style + image_rotary_emb (`torch.Tensor`): + The image rotary embeddings to apply on query and key tensors during attention calculation. + return_dict: bool + Whether to return a dictionary. + """ + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) + + temb = self.time_extra_emb( + timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype + ) # [B, D] + + # text projection + batch_size, sequence_length, _ = encoder_hidden_states_t5.shape + encoder_hidden_states_t5 = self.text_embedder( + encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1]) + ) + encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1) + + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1) + text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1) + text_embedding_mask = text_embedding_mask.unsqueeze(2).bool() + + encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding) + + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.config.num_layers // 2: + if controlnet_block_samples is not None: + skip = skips.pop() + controlnet_block_samples.pop() + else: + skip = skips.pop() + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + skip=skip, + ) # (N, L, D) + else: + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) # (N, L, D) + + if layer < (self.config.num_layers // 2 - 1): + skips.append(hidden_states) + + if controlnet_block_samples is not None and len(controlnet_block_samples) != 0: + raise ValueError("The number of controls is not equal to the number of skip connections.") + + # final layer + hidden_states = self.norm_out(hidden_states, temb.to(torch.float32)) + hidden_states = self.proj_out(hidden_states) + # (N, L, patch_size ** 2 * out_channels) + + # unpatchify: (N, out_channels, H, W) + patch_size = self.pos_embed.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) diff --git a/diffusers/src/diffusers/models/transformers/latte_transformer_3d.py b/diffusers/src/diffusers/models/transformers/latte_transformer_3d.py new file mode 100755 index 0000000..71d1921 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/latte_transformer_3d.py @@ -0,0 +1,327 @@ +# Copyright 2024 the Latte Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid +from ..attention import BasicTransformerBlock +from ..embeddings import PatchEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle + + +class LatteTransformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, offical code: + https://github.com/Vchitect/Latte + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input. + out_channels (`int`, *optional*): + The number of channels in the output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + patch_size (`int`, *optional*): + The size of the patches to use in the patch embedding layer. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. During inference, you can denoise for up to but not more steps than + `num_embeds_ada_norm`. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. + caption_channels (`int`, *optional*): + The number of channels in the caption embeddings. + video_length (`int`, *optional*): + The number of frames in the video-like data. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: int = 64, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + caption_channels: int = None, + video_length: int = 16, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + # 1. Define input layers + self.height = sample_size + self.width = sample_size + + interpolation_scale = self.config.sample_size // 64 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 2. Define spatial transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for d in range(num_layers) + ] + ) + + # 3. Define temporal transformers blocks + self.temporal_transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=None, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. Latte other blocks. + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + # define temporal positional embedding + temp_pos_embed = get_1d_sincos_pos_embed_from_grid( + inner_dim, torch.arange(0, video_length).unsqueeze(1) + ) # 1152 hidden size + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + enable_temporal_attentions: bool = True, + return_dict: bool = True, + ): + """ + The [`LatteTransformer3DModel`] forward method. + + Args: + hidden_states shape `(batch size, channel, num_frame, height, width)`: + Input `hidden_states`. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batcheight, sequence_length)` True = keep, False = discard. + * Bias `(batcheight, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + enable_temporal_attentions: + (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + + # Reshape hidden states + batch_size, channels, num_frame, height, width = hidden_states.shape + # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width) + + # Input + height, width = ( + hidden_states.shape[-2] // self.config.patch_size, + hidden_states.shape[-1] // self.config.patch_size, + ) + num_patches = height * width + + hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # Prepare text embeddings for spatial block + # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size + encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152 + encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view( + -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1] + ) + + # Prepare timesteps for spatial and temporal block + timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1]) + timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1]) + + # Spatial and temporal transformer blocks + for i, (spatial_block, temp_block) in enumerate( + zip(self.transformer_blocks, self.temporal_transformer_blocks) + ): + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + spatial_block, + hidden_states, + None, # attention_mask + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + None, # cross_attention_kwargs + None, # class_labels + use_reentrant=False, + ) + else: + hidden_states = spatial_block( + hidden_states, + None, # attention_mask + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + None, # cross_attention_kwargs + None, # class_labels + ) + + if enable_temporal_attentions: + # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size + hidden_states = hidden_states.reshape( + batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] + ).permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) + + if i == 0 and num_frame > 1: + hidden_states = hidden_states + self.temp_pos_embed + + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + None, # cross_attention_kwargs + None, # class_labels + use_reentrant=False, + ) + else: + hidden_states = temp_block( + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + None, # cross_attention_kwargs + None, # class_labels + ) + + # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size + hidden_states = hidden_states.reshape( + batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] + ).permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) + + embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1]) + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size) + ) + output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute( + 0, 2, 1, 3, 4 + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/diffusers/src/diffusers/models/transformers/lumina_nextdit2d.py b/diffusers/src/diffusers/models/transformers/lumina_nextdit2d.py new file mode 100755 index 0000000..d4f5b46 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/lumina_nextdit2d.py @@ -0,0 +1,340 @@ +# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention import LuminaFeedForward +from ..attention_processor import Attention, LuminaAttnProcessor2_0 +from ..embeddings import ( + LuminaCombinedTimestepCaptionEmbedding, + LuminaPatchEmbed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LuminaNextDiTBlock(nn.Module): + """ + A LuminaNextDiTBlock for LuminaNextDiT2DModel. + + Parameters: + dim (`int`): Embedding dimension of the input features. + num_attention_heads (`int`): Number of attention heads. + num_kv_heads (`int`): + Number of attention heads in key and value features (if using GQA), or set to None for the same as query. + multiple_of (`int`): The number of multiple of ffn layer. + ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension. + norm_eps (`float`): The eps for norm layer. + qk_norm (`bool`): normalization for query and key. + cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states. + norm_elementwise_affine (`bool`, *optional*, defaults to True), + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + cross_attention_dim: int, + norm_elementwise_affine: bool = True, + ) -> None: + super().__init__() + self.head_dim = dim // num_attention_heads + + self.gate = nn.Parameter(torch.zeros([num_attention_heads])) + + # Self-attention + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + qk_norm="layer_norm_across_heads" if qk_norm else None, + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=LuminaAttnProcessor2_0(), + ) + self.attn1.to_out = nn.Identity() + + # Cross-attention + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + dim_head=dim // num_attention_heads, + qk_norm="layer_norm_across_heads" if qk_norm else None, + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=LuminaAttnProcessor2_0(), + ) + + self.feed_forward = LuminaFeedForward( + dim=dim, + inner_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + + self.norm1 = LuminaRMSNormZero( + embedding_dim=dim, + norm_eps=norm_eps, + norm_elementwise_affine=norm_elementwise_affine, + ) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_mask: torch.Tensor, + temb: torch.Tensor, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Perform a forward pass through the LuminaNextDiTBlock. + + Parameters: + hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock. + attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask. + image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. + encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder. + encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask. + temb (`torch.Tensor`): Timestep embedding with text prompt embedding. + cross_attention_kwargs (`Dict[str, Any]`): kwargs for cross attention. + """ + residual = hidden_states + + # Self-attention + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + self_attn_output = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + query_rotary_emb=image_rotary_emb, + key_rotary_emb=image_rotary_emb, + **cross_attention_kwargs, + ) + + # Cross-attention + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + cross_attn_output = self.attn2( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=encoder_mask, + query_rotary_emb=image_rotary_emb, + key_rotary_emb=None, + **cross_attention_kwargs, + ) + cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1) + mixed_attn_output = self_attn_output + cross_attn_output + mixed_attn_output = mixed_attn_output.flatten(-2) + # linear proj + hidden_states = self.attn2.to_out[0](mixed_attn_output) + + hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states) + + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) + + return hidden_states + + +class LuminaNextDiT2DModel(ModelMixin, ConfigMixin): + """ + LuminaNextDiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters: + sample_size (`int`): The width of the latent images. This is fixed during training since + it is used to learn a number of position embeddings. + patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2): + The size of each patch in the image. This parameter defines the resolution of patches fed into the model. + in_channels (`int`, *optional*, defaults to 4): + The number of input channels for the model. Typically, this matches the number of channels in the input + images. + hidden_size (`int`, *optional*, defaults to 4096): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + num_layers (`int`, *optional*, default to 32): + The number of layers in the model. This defines the depth of the neural network. + num_attention_heads (`int`, *optional*, defaults to 32): + The number of attention heads in each attention layer. This parameter specifies how many separate attention + mechanisms are used. + num_kv_heads (`int`, *optional*, defaults to 8): + The number of key-value heads in the attention mechanism, if different from the number of attention heads. + If None, it defaults to num_attention_heads. + multiple_of (`int`, *optional*, defaults to 256): + A factor that the hidden size should be a multiple of. This can help optimize certain hardware + configurations. + ffn_dim_multiplier (`float`, *optional*): + A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on + the model configuration. + norm_eps (`float`, *optional*, defaults to 1e-5): + A small value added to the denominator for numerical stability in normalization layers. + learn_sigma (`bool`, *optional*, defaults to True): + Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in + predictions. + qk_norm (`bool`, *optional*, defaults to True): + Indicates if the queries and keys in the attention mechanism should be normalized. + cross_attention_dim (`int`, *optional*, defaults to 2048): + The dimensionality of the text embeddings. This parameter defines the size of the text representations used + in the model. + scaling_factor (`float`, *optional*, defaults to 1.0): + A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the + overall scale of the model's operations. + """ + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: Optional[int] = 2, + in_channels: Optional[int] = 4, + hidden_size: Optional[int] = 2304, + num_layers: Optional[int] = 32, + num_attention_heads: Optional[int] = 32, + num_kv_heads: Optional[int] = None, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: Optional[float] = 1e-5, + learn_sigma: Optional[bool] = True, + qk_norm: Optional[bool] = True, + cross_attention_dim: Optional[int] = 2048, + scaling_factor: Optional[float] = 1.0, + ) -> None: + super().__init__() + self.sample_size = sample_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.scaling_factor = scaling_factor + + self.patch_embedder = LuminaPatchEmbed( + patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True + ) + + self.pad_token = nn.Parameter(torch.empty(hidden_size)) + + self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding( + hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim + ) + + self.layers = nn.ModuleList( + [ + LuminaNextDiTBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels, + ) + # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels) + + assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4" + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + cross_attention_kwargs: Dict[str, Any] = None, + return_dict=True, + ) -> torch.Tensor: + """ + Forward pass of LuminaNextDiT. + + Parameters: + hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W). + timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,). + encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D). + encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L). + """ + hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb) + image_rotary_emb = image_rotary_emb.to(hidden_states.device) + + temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask) + + encoder_mask = encoder_mask.bool() + for layer in self.layers: + hidden_states = layer( + hidden_states, + mask, + image_rotary_emb, + encoder_hidden_states, + encoder_mask, + temb=temb, + cross_attention_kwargs=cross_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + + # unpatchify + height_tokens = width_tokens = self.patch_size + height, width = img_size[0] + batch_size = hidden_states.size(0) + sequence_length = (height // height_tokens) * (width // width_tokens) + hidden_states = hidden_states[:, :sequence_length].view( + batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels + ) + output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/diffusers/src/diffusers/models/transformers/pixart_transformer_2d.py b/diffusers/src/diffusers/models/transformers/pixart_transformer_2d.py new file mode 100755 index 0000000..1e5cd57 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -0,0 +1,445 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ..attention import BasicTransformerBlock +from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0 +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PixArtTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426, + https://arxiv.org/abs/2403.04692). + + Parameters: + num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (int, optional, defaults to 72): The number of channels in each head. + in_channels (int, defaults to 4): The number of channels in the input. + out_channels (int, optional): + The number of channels in the output. Specify this parameter if the output channel number differs from the + input. + num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use. + dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks. + norm_num_groups (int, optional, defaults to 32): + Number of groups for group normalization within Transformer blocks. + cross_attention_dim (int, optional): + The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension. + attention_bias (bool, optional, defaults to True): + Configure if the Transformer blocks' attention should contain a bias parameter. + sample_size (int, defaults to 128): + The width of the latent images. This parameter is fixed during training. + patch_size (int, defaults to 2): + Size of the patches the model processes, relevant for architectures working on non-sequential data. + activation_fn (str, optional, defaults to "gelu-approximate"): + Activation function to use in feed-forward networks within Transformer blocks. + num_embeds_ada_norm (int, optional, defaults to 1000): + Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during + inference. + upcast_attention (bool, optional, defaults to False): + If true, upcasts the attention mechanism dimensions for potentially improved performance. + norm_type (str, optional, defaults to "ada_norm_zero"): + Specifies the type of normalization used, can be 'ada_norm_zero'. + norm_elementwise_affine (bool, optional, defaults to False): + If true, enables element-wise affine parameters in the normalization layers. + norm_eps (float, optional, defaults to 1e-6): + A small constant added to the denominator in normalization layers to prevent division by zero. + interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings. + use_additional_conditions (bool, optional): If we're using additional conditions as inputs. + attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used. + caption_channels (int, optional, defaults to None): + Number of channels to use for projecting the caption embeddings. + use_linear_projection (bool, optional, defaults to False): + Deprecated argument. Will be removed in a future version. + num_vector_embeds (bool, optional, defaults to False): + Deprecated argument. Will be removed in a future version. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 72, + in_channels: int = 4, + out_channels: Optional[int] = 8, + num_layers: int = 28, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = 1152, + attention_bias: bool = True, + sample_size: int = 128, + patch_size: int = 2, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: Optional[int] = 1000, + upcast_attention: bool = False, + norm_type: str = "ada_norm_single", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + interpolation_scale: Optional[int] = None, + use_additional_conditions: Optional[bool] = None, + caption_channels: Optional[int] = None, + attention_type: Optional[str] = "default", + ): + super().__init__() + + # Validate inputs. + if norm_type != "ada_norm_single": + raise NotImplementedError( + f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." + ) + elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None: + raise ValueError( + f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." + ) + + # Set some common variables used across the board. + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.out_channels = in_channels if out_channels is None else out_channels + if use_additional_conditions is None: + if sample_size == 128: + use_additional_conditions = True + else: + use_additional_conditions = False + self.use_additional_conditions = use_additional_conditions + + self.gradient_checkpointing = False + + # 2. Initialize the position embedding and transformer blocks. + self.height = self.config.sample_size + self.width = self.config.sample_size + + interpolation_scale = ( + self.config.interpolation_scale + if self.config.interpolation_scale is not None + else max(self.config.sample_size // 64, 1) + ) + self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) + + # 3. Output blocks. + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels) + + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=self.use_additional_conditions + ) + self.caption_projection = None + if self.config.caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.config.caption_channels, hidden_size=self.inner_dim + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + + Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model. + """ + self.set_attn_processor(AttnProcessor()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`PixArtTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep (`torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.") + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size = hidden_states.shape[0] + height, width = ( + hidden_states.shape[-2] // self.config.patch_size, + hidden_states.shape[-1] // self.config.patch_size, + ) + hidden_states = self.pos_embed(hidden_states) + + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + # 2. Blocks + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + None, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=None, + ) + + # 3. Output + shift, scale = ( + self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device) + ).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/diffusers/src/diffusers/models/transformers/prior_transformer.py b/diffusers/src/diffusers/models/transformers/prior_transformer.py new file mode 100755 index 0000000..fdb6738 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/prior_transformer.py @@ -0,0 +1,380 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...utils import BaseOutput +from ..attention import BasicTransformerBlock +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin + + +@dataclass +class PriorTransformerOutput(BaseOutput): + """ + The output of [`PriorTransformer`]. + + Args: + predicted_image_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`): + The predicted CLIP image embedding conditioned on the CLIP text embedding input. + """ + + predicted_image_embedding: torch.Tensor + + +class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + """ + A Prior Transformer model. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states` + num_embeddings (`int`, *optional*, defaults to 77): + The number of embeddings of the model input `hidden_states` + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the + projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + + additional_embeddings`. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + time_embed_act_fn (`str`, *optional*, defaults to 'silu'): + The activation function to use to create timestep embeddings. + norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before + passing to Transformer blocks. Set it to `None` if normalization is not needed. + embedding_proj_norm_type (`str`, *optional*, defaults to None): + The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not + needed. + encoder_hid_proj_type (`str`, *optional*, defaults to `linear`): + The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if + `encoder_hidden_states` is `None`. + added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model. + Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot + product between the text embedding and image embedding as proposed in the unclip paper + https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended. + time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings. + If None, will be set to `num_attention_heads * attention_head_dim` + embedding_proj_dim (`int`, *optional*, default to None): + The dimension of `proj_embedding`. If None, will be set to `embedding_dim`. + clip_embed_dim (`int`, *optional*, default to None): + The dimension of the output. If None, will be set to `embedding_dim`. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 32, + attention_head_dim: int = 64, + num_layers: int = 20, + embedding_dim: int = 768, + num_embeddings=77, + additional_embeddings=4, + dropout: float = 0.0, + time_embed_act_fn: str = "silu", + norm_in_type: Optional[str] = None, # layer + embedding_proj_norm_type: Optional[str] = None, # layer + encoder_hid_proj_type: Optional[str] = "linear", # linear + added_emb_type: Optional[str] = "prd", # prd + time_embed_dim: Optional[int] = None, + embedding_proj_dim: Optional[int] = None, + clip_embed_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.additional_embeddings = additional_embeddings + + time_embed_dim = time_embed_dim or inner_dim + embedding_proj_dim = embedding_proj_dim or embedding_dim + clip_embed_dim = clip_embed_dim or embedding_dim + + self.time_proj = Timesteps(inner_dim, True, 0) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn) + + self.proj_in = nn.Linear(embedding_dim, inner_dim) + + if embedding_proj_norm_type is None: + self.embedding_proj_norm = None + elif embedding_proj_norm_type == "layer": + self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim) + else: + raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}") + + self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim) + + if encoder_hid_proj_type is None: + self.encoder_hidden_states_proj = None + elif encoder_hid_proj_type == "linear": + self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) + else: + raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}") + + self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) + + if added_emb_type == "prd": + self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) + elif added_emb_type is None: + self.prd_embedding = None + else: + raise ValueError( + f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`." + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn="gelu", + attention_bias=True, + ) + for d in range(num_layers) + ] + ) + + if norm_in_type == "layer": + self.norm_in = nn.LayerNorm(inner_dim) + elif norm_in_type is None: + self.norm_in = None + else: + raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.") + + self.norm_out = nn.LayerNorm(inner_dim) + + self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim) + + causal_attention_mask = torch.full( + [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 + ) + causal_attention_mask.triu_(1) + causal_attention_mask = causal_attention_mask[None, ...] + self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) + + self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim)) + self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def forward( + self, + hidden_states, + timestep: Union[torch.Tensor, float, int], + proj_embedding: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + return_dict: bool = True, + ): + """ + The [`PriorTransformer`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, embedding_dim)`): + The currently predicted image embeddings. + timestep (`torch.LongTensor`): + Current denoising step. + proj_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`): + Projected embedding vector the denoising process is conditioned on. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`): + Hidden states of the text embeddings the denoising process is conditioned on. + attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): + Text mask for the text embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformers.prior_transformer.PriorTransformerOutput`] instead of + a plain tuple. + + Returns: + [`~models.transformers.prior_transformer.PriorTransformerOutput`] or `tuple`: + If return_dict is True, a [`~models.transformers.prior_transformer.PriorTransformerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + batch_size = hidden_states.shape[0] + + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) + + timesteps_projected = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timesteps_projected = timesteps_projected.to(dtype=self.dtype) + time_embeddings = self.time_embedding(timesteps_projected) + + if self.embedding_proj_norm is not None: + proj_embedding = self.embedding_proj_norm(proj_embedding) + + proj_embeddings = self.embedding_proj(proj_embedding) + if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: + encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set") + + hidden_states = self.proj_in(hidden_states) + + positional_embeddings = self.positional_embedding.to(hidden_states.dtype) + + additional_embeds = [] + additional_embeddings_len = 0 + + if encoder_hidden_states is not None: + additional_embeds.append(encoder_hidden_states) + additional_embeddings_len += encoder_hidden_states.shape[1] + + if len(proj_embeddings.shape) == 2: + proj_embeddings = proj_embeddings[:, None, :] + + if len(hidden_states.shape) == 2: + hidden_states = hidden_states[:, None, :] + + additional_embeds = additional_embeds + [ + proj_embeddings, + time_embeddings[:, None, :], + hidden_states, + ] + + if self.prd_embedding is not None: + prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) + additional_embeds.append(prd_embedding) + + hidden_states = torch.cat( + additional_embeds, + dim=1, + ) + + # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens + additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1 + if positional_embeddings.shape[1] < hidden_states.shape[1]: + positional_embeddings = F.pad( + positional_embeddings, + ( + 0, + 0, + additional_embeddings_len, + self.prd_embedding.shape[1] if self.prd_embedding is not None else 0, + ), + value=0.0, + ) + + hidden_states = hidden_states + positional_embeddings + + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) + attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) + attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) + + if self.norm_in is not None: + hidden_states = self.norm_in(hidden_states) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask) + + hidden_states = self.norm_out(hidden_states) + + if self.prd_embedding is not None: + hidden_states = hidden_states[:, -1] + else: + hidden_states = hidden_states[:, additional_embeddings_len:] + + predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) + + if not return_dict: + return (predicted_image_embedding,) + + return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) + + def post_process_latents(self, prior_latents): + prior_latents = (prior_latents * self.clip_std) + self.clip_mean + return prior_latents diff --git a/diffusers/src/diffusers/models/transformers/stable_audio_transformer.py b/diffusers/src/diffusers/models/transformers/stable_audio_transformer.py new file mode 100755 index 0000000..e3462b5 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/stable_audio_transformer.py @@ -0,0 +1,458 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.attention import FeedForward +from ...models.attention_processor import ( + Attention, + AttentionProcessor, + StableAudioAttnProcessor2_0, +) +from ...models.modeling_utils import ModelMixin +from ...models.transformers.transformer_2d import Transformer2DModelOutput +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioGaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__ + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + del self.weight + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.weight = self.W + del self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :] + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +@maybe_allow_in_graph +class StableAudioDiTBlock(nn.Module): + r""" + Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip + connection and QKNorm + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for the query states. + num_key_value_attention_heads (`int`): The number of heads to use for the key and value states. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, + norm_eps: float = 1e-5, + ff_inner_dim: Optional[int] = None, + ): + super().__init__() + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) + + # 2. Cross-Attn + self.norm2 = nn.LayerNorm(dim, norm_eps, True) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + kv_heads=num_key_value_attention_heads, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, norm_eps, True) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn="swiglu", + final_dropout=False, + inner_dim=ff_inner_dim, + bias=True, + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + rotary_embedding: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + attention_mask=attention_mask, + rotary_emb=rotary_embedding, + ) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class StableAudioDiTModel(ModelMixin, ConfigMixin): + """ + The Diffusion Transformer model introduced in Stable Audio. + + Reference: https://github.com/Stability-AI/stable-audio-tools + + Parameters: + sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample. + in_channels (`int`, *optional*, defaults to 64): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states. + num_key_value_attention_heads (`int`, *optional*, defaults to 12): + The number of heads to use for the key and value states. + out_channels (`int`, defaults to 64): Number of output channels. + cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection. + time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection. + global_states_input_dim ( `int`, *optional*, defaults to 1536): + Input dimension of the global hidden states projection. + cross_attention_input_dim ( `int`, *optional*, defaults to 768): + Input dimension of the cross-attention projection + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 1024, + in_channels: int = 64, + num_layers: int = 24, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + num_key_value_attention_heads: int = 12, + out_channels: int = 64, + cross_attention_dim: int = 768, + time_proj_dim: int = 256, + global_states_input_dim: int = 1536, + cross_attention_input_dim: int = 768, + ): + super().__init__() + self.sample_size = sample_size + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.time_proj = StableAudioGaussianFourierProjection( + embedding_size=time_proj_dim // 2, + flip_sin_to_cos=True, + log=False, + set_W_to_weight=False, + ) + + self.timestep_proj = nn.Sequential( + nn.Linear(time_proj_dim, self.inner_dim, bias=True), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=True), + ) + + self.global_proj = nn.Sequential( + nn.Linear(global_states_input_dim, self.inner_dim, bias=False), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=False), + ) + + self.cross_attention_proj = nn.Sequential( + nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), + nn.SiLU(), + nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), + ) + + self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False) + self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [ + StableAudioDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for i in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False) + self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(StableAudioAttnProcessor2_0()) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.FloatTensor, + timestep: torch.LongTensor = None, + encoder_hidden_states: torch.FloatTensor = None, + global_hidden_states: torch.FloatTensor = None, + rotary_embedding: torch.FloatTensor = None, + return_dict: bool = True, + attention_mask: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`StableAudioDiTModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`): + Input `hidden_states`. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`): + Global embeddings that will be prepended to the hidden states. + rotary_embedding (`torch.Tensor`): + The rotary embeddings to apply on query and key tensors during attention calculation. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token indices, formed by concatenating the attention + masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating + the attention masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states) + global_hidden_states = self.global_proj(global_hidden_states) + time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype))) + + global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1) + + hidden_states = self.preprocess_conv(hidden_states) + hidden_states + # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim) + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.proj_in(hidden_states) + + # prepend global states to hidden states + hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2) + if attention_mask is not None: + prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool) + attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + cross_attention_hidden_states, + encoder_attention_mask, + rotary_embedding, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=cross_attention_hidden_states, + encoder_attention_mask=encoder_attention_mask, + rotary_embedding=rotary_embedding, + ) + + hidden_states = self.proj_out(hidden_states) + + # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length) + # remove prepend length that has been added by global hidden states + hidden_states = hidden_states.transpose(1, 2)[:, :, 1:] + hidden_states = self.postprocess_conv(hidden_states) + hidden_states + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/diffusers/src/diffusers/models/transformers/t5_film_transformer.py b/diffusers/src/diffusers/models/transformers/t5_film_transformer.py new file mode 100755 index 0000000..1dea37a --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/t5_film_transformer.py @@ -0,0 +1,436 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional, Tuple + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ..attention_processor import Attention +from ..embeddings import get_timestep_embedding +from ..modeling_utils import ModelMixin + + +class T5FilmDecoder(ModelMixin, ConfigMixin): + r""" + T5 style decoder with FiLM conditioning. + + Args: + input_dims (`int`, *optional*, defaults to `128`): + The number of input dimensions. + targets_length (`int`, *optional*, defaults to `256`): + The length of the targets. + d_model (`int`, *optional*, defaults to `768`): + Size of the input hidden states. + num_layers (`int`, *optional*, defaults to `12`): + The number of `DecoderLayer`'s to use. + num_heads (`int`, *optional*, defaults to `12`): + The number of attention heads to use. + d_kv (`int`, *optional*, defaults to `64`): + Size of the key-value projection vectors. + d_ff (`int`, *optional*, defaults to `2048`): + The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s. + dropout_rate (`float`, *optional*, defaults to `0.1`): + Dropout probability. + """ + + @register_to_config + def __init__( + self, + input_dims: int = 128, + targets_length: int = 256, + max_decoder_noise_time: float = 2000.0, + d_model: int = 768, + num_layers: int = 12, + num_heads: int = 12, + d_kv: int = 64, + d_ff: int = 2048, + dropout_rate: float = 0.1, + ): + super().__init__() + + self.conditioning_emb = nn.Sequential( + nn.Linear(d_model, d_model * 4, bias=False), + nn.SiLU(), + nn.Linear(d_model * 4, d_model * 4, bias=False), + nn.SiLU(), + ) + + self.position_encoding = nn.Embedding(targets_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False) + + self.dropout = nn.Dropout(p=dropout_rate) + + self.decoders = nn.ModuleList() + for lyr_num in range(num_layers): + # FiLM conditional T5 decoder + lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate) + self.decoders.append(lyr) + + self.decoder_norm = T5LayerNorm(d_model) + + self.post_dropout = nn.Dropout(p=dropout_rate) + self.spec_out = nn.Linear(d_model, input_dims, bias=False) + + def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor: + mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) + return mask.unsqueeze(-3) + + def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time): + batch, _, _ = decoder_input_tokens.shape + assert decoder_noise_time.shape == (batch,) + + # decoder_noise_time is in [0, 1), so rescale to expected timing range. + time_steps = get_timestep_embedding( + decoder_noise_time * self.config.max_decoder_noise_time, + embedding_dim=self.config.d_model, + max_period=self.config.max_decoder_noise_time, + ).to(dtype=self.dtype) + + conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1) + + assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4) + + seq_length = decoder_input_tokens.shape[1] + + # If we want to use relative positions for audio context, we can just offset + # this sequence by the length of encodings_and_masks. + decoder_positions = torch.broadcast_to( + torch.arange(seq_length, device=decoder_input_tokens.device), + (batch, seq_length), + ) + + position_encodings = self.position_encoding(decoder_positions) + + inputs = self.continuous_inputs_projection(decoder_input_tokens) + inputs += position_encodings + y = self.dropout(inputs) + + # decoder: No padding present. + decoder_mask = torch.ones( + decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype + ) + + # Translate encoding masks to encoder-decoder masks. + encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks] + + # cross attend style: concat encodings + encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1) + encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1) + + for lyr in self.decoders: + y = lyr( + y, + conditioning_emb=conditioning_emb, + encoder_hidden_states=encoded, + encoder_attention_mask=encoder_decoder_mask, + )[0] + + y = self.decoder_norm(y) + y = self.post_dropout(y) + + spec_out = self.spec_out(y) + return spec_out + + +class DecoderLayer(nn.Module): + r""" + T5 decoder layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__( + self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6 + ): + super().__init__() + self.layer = nn.ModuleList() + + # cond self attention: layer 0 + self.layer.append( + T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) + ) + + # cross attention: layer 1 + self.layer.append( + T5LayerCrossAttention( + d_model=d_model, + d_kv=d_kv, + num_heads=num_heads, + dropout_rate=dropout_rate, + layer_norm_epsilon=layer_norm_epsilon, + ) + ) + + # Film Cond MLP + dropout: last layer + self.layer.append( + T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) + ) + + def forward( + self, + hidden_states: torch.Tensor, + conditioning_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + encoder_decoder_position_bias=None, + ) -> Tuple[torch.Tensor]: + hidden_states = self.layer[0]( + hidden_states, + conditioning_emb=conditioning_emb, + attention_mask=attention_mask, + ) + + if encoder_hidden_states is not None: + encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to( + encoder_hidden_states.dtype + ) + + hidden_states = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_extended_attention_mask, + ) + + # Apply Film Conditional Feed Forward layer + hidden_states = self.layer[-1](hidden_states, conditioning_emb) + + return (hidden_states,) + + +class T5LayerSelfAttentionCond(nn.Module): + r""" + T5 style self-attention layer with conditioning. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + dropout_rate (`float`): + Dropout probability. + """ + + def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float): + super().__init__() + self.layer_norm = T5LayerNorm(d_model) + self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + conditioning_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # pre_self_attention_layer_norm + normed_hidden_states = self.layer_norm(hidden_states) + + if conditioning_emb is not None: + normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb) + + # Self-attention block + attention_output = self.attention(normed_hidden_states) + + hidden_states = hidden_states + self.dropout(attention_output) + + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + r""" + T5 style cross-attention layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float): + super().__init__() + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.attention( + normed_hidden_states, + encoder_hidden_states=key_value_states, + attention_mask=attention_mask.squeeze(1), + ) + layer_output = hidden_states + self.dropout(attention_output) + return layer_output + + +class T5LayerFFCond(nn.Module): + r""" + T5 style feed-forward conditional layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) + self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor: + forwarded_states = self.layer_norm(hidden_states) + if conditioning_emb is not None: + forwarded_states = self.film(forwarded_states, conditioning_emb) + + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + r""" + T5 style feed-forward layer with gated activations and dropout. + + Args: + d_model (`int`): + Size of the input hidden states. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + """ + + def __init__(self, d_model: int, d_ff: int, dropout_rate: float): + super().__init__() + self.wi_0 = nn.Linear(d_model, d_ff, bias=False) + self.wi_1 = nn.Linear(d_model, d_ff, bias=False) + self.wo = nn.Linear(d_ff, d_model, bias=False) + self.dropout = nn.Dropout(dropout_rate) + self.act = NewGELUActivation() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerNorm(nn.Module): + r""" + T5 style layer normalization module. + + Args: + hidden_size (`int`): + Size of the input hidden states. + eps (`float`, `optional`, defaults to `1e-6`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + +class T5FiLMLayer(nn.Module): + """ + T5 style FiLM Layer. + + Args: + in_features (`int`): + Number of input features. + out_features (`int`): + Number of output features. + """ + + def __init__(self, in_features: int, out_features: int): + super().__init__() + self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) + + def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor: + emb = self.scale_bias(conditioning_emb) + scale, shift = torch.chunk(emb, 2, -1) + x = x * (1 + scale) + shift + return x diff --git a/diffusers/src/diffusers/models/transformers/transformer_2d.py b/diffusers/src/diffusers/models/transformers/transformer_2d.py new file mode 100755 index 0000000..c7c19e4 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/transformer_2d.py @@ -0,0 +1,566 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import LegacyConfigMixin, register_to_config +from ...utils import deprecate, is_torch_version, logging +from ..attention import BasicTransformerBlock +from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import LegacyModelMixin +from ..normalization import AdaLayerNormSingle + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Transformer2DModelOutput(Transformer2DModelOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead." + deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock"] + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + interpolation_scale: float = None, + use_additional_conditions: Optional[bool] = None, + ): + super().__init__() + + # Validate inputs. + if patch_size is not None: + if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]: + raise NotImplementedError( + f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." + ) + elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None: + raise ValueError( + f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." + ) + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + # Set some common variables used across the board. + self.use_linear_projection = use_linear_projection + self.interpolation_scale = interpolation_scale + self.caption_channels = caption_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.gradient_checkpointing = False + + if use_additional_conditions is None: + if norm_type == "ada_norm_single" and sample_size == 128: + use_additional_conditions = True + else: + use_additional_conditions = False + self.use_additional_conditions = use_additional_conditions + + # 2. Initialize the right blocks. + # These functions follow a common structure: + # a. Initialize the input blocks. b. Initialize the transformer blocks. + # c. Initialize the output blocks and other projection blocks when necessary. + if self.is_input_continuous: + self._init_continuous_input(norm_type=norm_type) + elif self.is_input_vectorized: + self._init_vectorized_inputs(norm_type=norm_type) + elif self.is_input_patches: + self._init_patched_inputs(norm_type=norm_type) + + def _init_continuous_input(self, norm_type): + self.norm = torch.nn.GroupNorm( + num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True + ) + if self.use_linear_projection: + self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim) + else: + self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) + + if self.use_linear_projection: + self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels) + else: + self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0) + + def _init_vectorized_inputs(self, norm_type): + assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert ( + self.config.num_vector_embeds is not None + ), "Transformer2DModel over discrete input must provide num_embed" + + self.height = self.config.sample_size + self.width = self.config.sample_size + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) + + self.norm_out = nn.LayerNorm(self.inner_dim) + self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1) + + def _init_patched_inputs(self, norm_type): + assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = self.config.sample_size + self.width = self.config.sample_size + + self.patch_size = self.config.patch_size + interpolation_scale = ( + self.config.interpolation_scale + if self.config.interpolation_scale is not None + else max(self.config.sample_size // 64, 1) + ) + self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) + + if self.config.norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Linear( + self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels + ) + elif self.config.norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear( + self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels + ) + + # PixArt-Alpha blocks. + self.adaln_single = None + if self.config.norm_type == "ada_norm_single": + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=self.use_additional_conditions + ) + + self.caption_projection = None + if self.caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, hidden_size=self.inner_dim + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned, + otherwise a `tuple` where the first element is the sample tensor. + """ + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + if self.is_input_continuous: + batch_size, _, height, width = hidden_states.shape + residual = hidden_states + hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, added_cond_kwargs + ) + + # 2. Blocks + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + output = self._get_output_for_continuous_inputs( + hidden_states=hidden_states, + residual=residual, + batch_size=batch_size, + height=height, + width=width, + inner_dim=inner_dim, + ) + elif self.is_input_vectorized: + output = self._get_output_for_vectorized_inputs(hidden_states) + elif self.is_input_patches: + output = self._get_output_for_patched_inputs( + hidden_states=hidden_states, + timestep=timestep, + class_labels=class_labels, + embedded_timestep=embedded_timestep, + height=height, + width=width, + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def _operate_on_continuous_inputs(self, hidden_states): + batch, _, height, width = hidden_states.shape + hidden_states = self.norm(hidden_states) + + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + return hidden_states, inner_dim + + def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs): + batch_size = hidden_states.shape[0] + hidden_states = self.pos_embed(hidden_states) + embedded_timestep = None + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + return hidden_states, encoder_hidden_states, timestep, embedded_timestep + + def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim): + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + return output + + def _get_output_for_vectorized_inputs(self, hidden_states): + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + return output + + def _get_output_for_patched_inputs( + self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None + ): + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + return output diff --git a/diffusers/src/diffusers/models/transformers/transformer_flux.py b/diffusers/src/diffusers/models/transformers/transformer_flux.py new file mode 100755 index 0000000..79bfecc --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/transformer_flux.py @@ -0,0 +1,580 @@ +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...models.attention import FeedForward +from ...models.attention_processor import ( + Attention, + AttentionProcessor, + FluxAttnProcessor2_0, + FluxAttnProcessor2_0_NPU, + FusedFluxAttnProcessor2_0, +) +from ...models.modeling_utils import ModelMixin +from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils.import_utils import is_torch_npu_available +from ...utils.torch_utils import maybe_allow_in_graph +from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from ..modeling_outputs import Transformer2DModelOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class FluxSingleTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + if is_torch_npu_available(): + processor = FluxAttnProcessor2_0_NPU() + else: + processor = FluxAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm="rms_norm", + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + joint_attention_kwargs=None, + ): + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states + + +@maybe_allow_in_graph +class FluxTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + + self.norm1_context = AdaLayerNormZero(dim) + + if hasattr(F, "scaled_dot_product_attention"): + processor = FluxAttnProcessor2_0() + else: + raise ValueError( + "The current PyTorch version does not support the `scaled_dot_product_attention` function." + ) + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + joint_attention_kwargs=None, + ): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + The Transformer model introduced in Flux. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Parameters: + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. + num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: Tuple[int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim + ) + + self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedFluxAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. + if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( + hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + controlnet_single_block_samples[index_block // interval_control] + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) \ No newline at end of file diff --git a/diffusers/src/diffusers/models/transformers/transformer_sd3.py b/diffusers/src/diffusers/models/transformers/transformer_sd3.py new file mode 100755 index 0000000..9376c91 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/transformer_sd3.py @@ -0,0 +1,367 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...models.attention import JointTransformerBlock +from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 +from ...models.modeling_utils import ModelMixin +from ...models.normalization import AdaLayerNormContinuous +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed +from ..modeling_outputs import Transformer2DModelOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + The Transformer model introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + sample_size (`int`): The width of the latent images. This is fixed during training since + it is used to learn a number of position embeddings. + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + out_channels (`int`, defaults to 16): Number of output channels. + + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 18, + attention_head_dim: int = 64, + num_attention_heads: int = 18, + joint_attention_dim: int = 4096, + caption_projection_dim: int = 1152, + pooled_projection_dim: int = 2048, + out_channels: int = 16, + pos_embed_max_size: int = 96, + ): + super().__init__() + default_out_channels = in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + pos_embed_max_size=pos_embed_max_size, # hard-code for now. + ) + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim + ) + self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim) + + # `attention_head_dim` is doubled to account for the mixing. + # It needs to crafted when we get the actual checkpoints. + self.transformer_blocks = nn.ModuleList( + [ + JointTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + context_pre_only=i == num_layers - 1, + ) + for i in range(self.config.num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedJointAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + block_controlnet_hidden_states: List = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`SD3Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + + # controlnet residual + if block_controlnet_hidden_states is not None and block.context_pre_only is False: + interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) + hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/diffusers/src/diffusers/models/transformers/transformer_temporal.py b/diffusers/src/diffusers/models/transformers/transformer_temporal.py new file mode 100755 index 0000000..c0c5467 --- /dev/null +++ b/diffusers/src/diffusers/models/transformers/transformer_temporal.py @@ -0,0 +1,381 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..resnet import AlphaBlender + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + The output of [`TransformerTemporalModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. + """ + + sample: torch.Tensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.LongTensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: torch.LongTensor = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> TransformerTemporalModelOutput: + """ + The [`TransformerTemporal`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] + instead of a plain tuple. + + Returns: + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) + + +class TransformerSpatioTemporalModel(nn.Module): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + out_channels (`int`, *optional*): + The number of channels in the output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int = 320, + out_channels: Optional[int] = None, + num_layers: int = 1, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + + # 2. Define input layers + self.in_channels = in_channels + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for d in range(num_layers) + ] + ) + + time_mix_inner_dim = inner_dim + self.temporal_transformer_blocks = nn.ModuleList( + [ + TemporalBasicTransformerBlock( + inner_dim, + time_mix_inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + + time_embed_dim = in_channels * 4 + self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) + self.time_proj = Timesteps(in_channels, True, 0) + self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + # TODO: should use out_channels for continuous projections + self.proj_out = nn.Linear(inner_dim, in_channels) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + Args: + hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`): + Input hidden_states. + num_frames (`int`): + The number of frames to be processed per batch. This is used to reshape the hidden states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*): + A tensor indicating whether the input contains only images. 1 indicates that the input contains only + images, 0 indicates that the input contains video frames. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] + instead of a plain tuple. + + Returns: + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, _, height, width = hidden_states.shape + num_frames = image_only_indicator.shape[-1] + batch_size = batch_frames // num_frames + + time_context = encoder_hidden_states + time_context_first_timestep = time_context[None, :].reshape( + batch_size, num_frames, -1, time_context.shape[-1] + )[:, 0] + time_context = time_context_first_timestep[:, None].broadcast_to( + batch_size, height * width, time_context.shape[-2], time_context.shape[-1] + ) + time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1]) + + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + num_frames_emb = torch.arange(num_frames, device=hidden_states.device) + num_frames_emb = num_frames_emb.repeat(batch_size, 1) + num_frames_emb = num_frames_emb.reshape(-1) + t_emb = self.time_proj(num_frames_emb) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + # 2. Blocks + for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + None, + encoder_hidden_states, + None, + use_reentrant=False, + ) + else: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states_mix = hidden_states + hidden_states_mix = hidden_states_mix + emb + + hidden_states_mix = temporal_block( + hidden_states_mix, + num_frames=num_frames, + encoder_hidden_states=time_context, + ) + hidden_states = self.time_mixer( + x_spatial=hidden_states, + x_temporal=hidden_states_mix, + image_only_indicator=image_only_indicator, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/diffusers/src/diffusers/models/unets/__init__.py b/diffusers/src/diffusers/models/unets/__init__.py new file mode 100755 index 0000000..9ef04fb --- /dev/null +++ b/diffusers/src/diffusers/models/unets/__init__.py @@ -0,0 +1,18 @@ +from ...utils import is_flax_available, is_torch_available + + +if is_torch_available(): + from .unet_1d import UNet1DModel + from .unet_2d import UNet2DModel + from .unet_2d_condition import UNet2DConditionModel + from .unet_3d_condition import UNet3DConditionModel + from .unet_i2vgen_xl import I2VGenXLUNet + from .unet_kandinsky3 import Kandinsky3UNet + from .unet_motion_model import MotionAdapter, UNetMotionModel + from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel + from .unet_stable_cascade import StableCascadeUNet + from .uvit_2d import UVit2DModel + + +if is_flax_available(): + from .unet_2d_condition_flax import FlaxUNet2DConditionModel diff --git a/diffusers/src/diffusers/models/unets/unet_1d.py b/diffusers/src/diffusers/models/unets/unet_1d.py new file mode 100755 index 0000000..8efabd9 --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_1d.py @@ -0,0 +1,255 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block + + +@dataclass +class UNet1DOutput(BaseOutput): + """ + The output of [`UNet1DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, sample_size)`): + The hidden states output from the last layer of the model. + """ + + sample: torch.Tensor + + +class UNet1DModel(ModelMixin, ConfigMixin): + r""" + A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. + in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. + extra_in_channels (`int`, *optional*, defaults to 0): + Number of additional channels to be added to the input of the first down block. Useful for cases where the + input data has more channels than what the model was initially designed for. + time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. + freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`): + Tuple of block output channels. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet. + act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization. + layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block. + downsample_each_block (`int`, *optional*, defaults to `False`): + Experimental feature for using a UNet without upsampling. + """ + + @register_to_config + def __init__( + self, + sample_size: int = 65536, + sample_rate: Optional[int] = None, + in_channels: int = 2, + out_channels: int = 2, + extra_in_channels: int = 0, + time_embedding_type: str = "fourier", + flip_sin_to_cos: bool = True, + use_timestep_embedding: bool = False, + freq_shift: float = 0.0, + down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), + up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), + mid_block_type: Tuple[str] = "UNetMidBlock1D", + out_block_type: str = None, + block_out_channels: Tuple[int] = (32, 32, 64), + act_fn: str = None, + norm_num_groups: int = 8, + layers_per_block: int = 1, + downsample_each_block: bool = False, + ): + super().__init__() + self.sample_size = sample_size + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection( + embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift + ) + timestep_input_dim = block_out_channels[0] + + if use_timestep_embedding: + time_embed_dim = block_out_channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=timestep_input_dim, + time_embed_dim=time_embed_dim, + act_fn=act_fn, + out_dim=block_out_channels[0], + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.out_block = None + + # down + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + + if i == 0: + input_channel += extra_in_channels + + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=not is_final_block or downsample_each_block, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + in_channels=block_out_channels[-1], + mid_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + embed_dim=block_out_channels[0], + num_layers=layers_per_block, + add_downsample=downsample_each_block, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + if out_block_type is None: + final_upsample_channels = out_channels + else: + final_upsample_channels = block_out_channels[0] + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = ( + reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels + ) + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block, + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.out_block = get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + embed_dim=block_out_channels[0], + out_channels=out_channels, + act_fn=act_fn, + fc_dim=block_out_channels[-1] // 4, + ) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, + ) -> Union[UNet1DOutput, Tuple]: + r""" + The [`UNet1DModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_1d.UNet1DOutput`] instead of a plain tuple. + + Returns: + [`~models.unets.unet_1d.UNet1DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is + returned where the first element is the sample tensor. + """ + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + timestep_embed = self.time_proj(timesteps) + if self.config.use_timestep_embedding: + timestep_embed = self.time_mlp(timestep_embed) + else: + timestep_embed = timestep_embed[..., None] + timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) + timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:])) + + # 2. down + down_block_res_samples = () + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) + down_block_res_samples += res_samples + + # 3. mid + if self.mid_block: + sample = self.mid_block(sample, timestep_embed) + + # 4. up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-1:] + down_block_res_samples = down_block_res_samples[:-1] + sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) + + # 5. post-process + if self.out_block: + sample = self.out_block(sample, timestep_embed) + + if not return_dict: + return (sample,) + + return UNet1DOutput(sample=sample) diff --git a/diffusers/src/diffusers/models/unets/unet_1d_blocks.py b/diffusers/src/diffusers/models/unets/unet_1d_blocks.py new file mode 100755 index 0000000..8fc27e9 --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_1d_blocks.py @@ -0,0 +1,702 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..activations import get_activation +from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims + + +class DownResnetBlock1D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + conv_shortcut: bool = False, + temb_channels: int = 32, + groups: int = 32, + groups_out: Optional[int] = None, + non_linearity: Optional[str] = None, + time_embedding_norm: str = "default", + output_scale_factor: float = 1.0, + add_downsample: bool = True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + output_states = () + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + + return hidden_states, output_states + + +class UpResnetBlock1D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + temb_channels: int = 32, + groups: int = 32, + groups_out: Optional[int] = None, + non_linearity: Optional[str] = None, + time_embedding_norm: str = "default", + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.add_upsample = add_upsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv_transpose=True) + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Optional[Tuple[torch.Tensor, ...]] = None, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if res_hidden_states_tuple is not None: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class ValueFunctionMidBlock1D(nn.Module): + def __init__(self, in_channels: int, out_channels: int, embed_dim: int): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + + self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) + self.down1 = Downsample1D(out_channels // 2, use_conv=True) + self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) + self.down2 = Downsample1D(out_channels // 4, use_conv=True) + + def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self.res1(x, temb) + x = self.down1(x) + x = self.res2(x, temb) + x = self.down2(x) + return x + + +class MidResTemporalBlock1D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + embed_dim: int, + num_layers: int = 1, + add_downsample: bool = False, + add_upsample: bool = False, + non_linearity: Optional[str] = None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_downsample = add_downsample + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv=True) + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True) + + if self.upsample and self.downsample: + raise ValueError("Block cannot downsample and upsample") + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.upsample: + hidden_states = self.upsample(hidden_states) + if self.downsample: + self.downsample = self.downsample(hidden_states) + + return hidden_states + + +class OutConv1DBlock(nn.Module): + def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str): + super().__init__() + self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) + self.final_conv1d_act = get_activation(act_fn) + self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.final_conv1d_1(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_gn(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_act(hidden_states) + hidden_states = self.final_conv1d_2(hidden_states) + return hidden_states + + +class OutValueFunctionBlock(nn.Module): + def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"): + super().__init__() + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + embed_dim, fc_dim // 2), + get_activation(act_fn), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.view(hidden_states.shape[0], -1) + hidden_states = torch.cat((hidden_states, temb), dim=-1) + for layer in self.final_block: + hidden_states = layer(hidden_states) + + return hidden_states + + +_kernels = { + "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], + "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], + "lanczos3": [ + 0.003689131001010537, + 0.015056144446134567, + -0.03399861603975296, + -0.066637322306633, + 0.13550527393817902, + 0.44638532400131226, + 0.44638532400131226, + 0.13550527393817902, + -0.066637322306633, + -0.03399861603975296, + 0.015056144446134567, + 0.003689131001010537, + ], +} + + +class Downsample1d(nn.Module): + def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel + return F.conv1d(hidden_states, weight, stride=2) + + +class Upsample1d(nn.Module): + def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel + return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) + + +class SelfAttention1d(nn.Module): + def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0): + super().__init__() + self.channels = in_channels + self.group_norm = nn.GroupNorm(1, num_channels=in_channels) + self.num_heads = n_head + + self.query = nn.Linear(self.channels, self.channels) + self.key = nn.Linear(self.channels, self.channels) + self.value = nn.Linear(self.channels, self.channels) + + self.proj_attn = nn.Linear(self.channels, self.channels, bias=True) + + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + batch, channel_dim, seq = hidden_states.shape + + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) + + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) + attention_probs = torch.softmax(attention_scores, dim=-1) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.dropout(hidden_states) + + output = hidden_states + residual + + return output + + +class ResConvBlock(nn.Module): + def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False): + super().__init__() + self.is_last = is_last + self.has_conv_skip = in_channels != out_channels + + if self.has_conv_skip: + self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) + + self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) + self.group_norm_1 = nn.GroupNorm(1, mid_channels) + self.gelu_1 = nn.GELU() + self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) + + if not self.is_last: + self.group_norm_2 = nn.GroupNorm(1, out_channels) + self.gelu_2 = nn.GELU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states + + hidden_states = self.conv_1(hidden_states) + hidden_states = self.group_norm_1(hidden_states) + hidden_states = self.gelu_1(hidden_states) + hidden_states = self.conv_2(hidden_states) + + if not self.is_last: + hidden_states = self.group_norm_2(hidden_states) + hidden_states = self.gelu_2(hidden_states) + + output = hidden_states + residual + return output + + +class UNetMidBlock1D(nn.Module): + def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None): + super().__init__() + + out_channels = in_channels if out_channels is None else out_channels + + # there is always at least one resnet + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.up = Upsample1d(kernel="cubic") + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.down(hidden_states) + for attn, resnet in zip(self.attentions, self.resnets): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class AttnDownBlock1D(nn.Module): + def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.down(hidden_states) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1D(nn.Module): + def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.down(hidden_states) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1DNoSkip(nn.Module): + def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = torch.cat([hidden_states, temb], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class AttnUpBlock1D(nn.Module): + def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1D(nn.Module): + def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1DNoSkip(nn.Module): + def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states + + +DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip] +MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D] +OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock] +UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip] + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, +) -> DownBlockType: + if down_block_type == "DownResnetBlock1D": + return DownResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "DownBlock1D": + return DownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "AttnDownBlock1D": + return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "DownBlock1DNoSkip": + return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool +) -> UpBlockType: + if up_block_type == "UpResnetBlock1D": + return UpResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "UpBlock1D": + return UpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "AttnUpBlock1D": + return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "UpBlock1DNoSkip": + return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) + raise ValueError(f"{up_block_type} does not exist.") + + +def get_mid_block( + mid_block_type: str, + num_layers: int, + in_channels: int, + mid_channels: int, + out_channels: int, + embed_dim: int, + add_downsample: bool, +) -> MidBlockType: + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) + elif mid_block_type == "ValueFunctionMidBlock1D": + return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) + elif mid_block_type == "UNetMidBlock1D": + return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) + raise ValueError(f"{mid_block_type} does not exist.") + + +def get_out_block( + *, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int +) -> Optional[OutBlockType]: + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim, act_fn) + return None diff --git a/diffusers/src/diffusers/models/unets/unet_2d.py b/diffusers/src/diffusers/models/unets/unet_2d.py new file mode 100755 index 0000000..5972505 --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_2d.py @@ -0,0 +1,346 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +@dataclass +class UNet2DOutput(BaseOutput): + """ + The output of [`UNet2DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output from the last layer of the model. + """ + + sample: torch.Tensor + + +class UNet2DModel(ModelMixin, ConfigMixin): + r""" + A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) - + 1)`. + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): + Tuple of downsample block types. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): + Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. + up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): + Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. + mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. + downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + downsample_type (`str`, *optional*, defaults to `conv`): + The downsample type for downsampling layers. Choose between "conv" and "resnet" + upsample_type (`str`, *optional*, defaults to `conv`): + The upsample type for upsampling layers. Choose between "conv" and "resnet" + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. + norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. + attn_norm_num_groups (`int`, *optional*, defaults to `None`): + If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the + given number of groups. If left as `None`, the group norm layer will only be created if + `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups. + norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class + conditioning with `class_embed_type` equal to `None`. + """ + + @register_to_config + def __init__( + self, + sample_size: Optional[Union[int, Tuple[int, int]]] = None, + in_channels: int = 3, + out_channels: int = 3, + center_input_sample: bool = False, + time_embedding_type: str = "positional", + freq_shift: int = 0, + flip_sin_to_cos: bool = True, + down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), + block_out_channels: Tuple[int, ...] = (224, 448, 672, 896), + layers_per_block: int = 2, + mid_block_scale_factor: float = 1, + downsample_padding: int = 1, + downsample_type: str = "conv", + upsample_type: str = "conv", + dropout: float = 0.0, + act_fn: str = "silu", + attention_head_dim: Optional[int] = 8, + norm_num_groups: int = 32, + attn_norm_num_groups: Optional[int] = None, + norm_eps: float = 1e-5, + resnet_time_scale_shift: str = "default", + add_attention: bool = True, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + num_train_timesteps: Optional[int] = None, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + elif time_embedding_type == "learned": + self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0]) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], + resnet_groups=norm_num_groups, + attn_groups=attn_norm_num_groups, + add_attention=add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DOutput, Tuple]: + r""" + The [`UNet2DModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d.UNet2DOutput`] instead of a plain tuple. + + Returns: + [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is + returned where the first element is the sample tensor. + """ + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when doing class conditioning") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + elif self.class_embedding is None and class_labels is not None: + raise ValueError("class_embedding needs to be initialized in order to use class conditioning") + + # 2. pre-process + skip_sample = sample + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "skip_conv"): + sample, res_samples, skip_sample = downsample_block( + hidden_states=sample, temb=emb, skip_sample=skip_sample + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb) + + # 5. up + skip_sample = None + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "skip_conv"): + sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) + else: + sample = upsample_block(sample, res_samples, emb) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if skip_sample is not None: + sample += skip_sample + + if self.config.time_embedding_type == "fourier": + timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) + sample = sample / timesteps + + if not return_dict: + return (sample,) + + return UNet2DOutput(sample=sample) diff --git a/diffusers/src/diffusers/models/unets/unet_2d_blocks.py b/diffusers/src/diffusers/models/unets/unet_2d_blocks.py new file mode 100755 index 0000000..93a0a82 --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_2d_blocks.py @@ -0,0 +1,3740 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ...utils import deprecate, is_torch_version, logging +from ...utils.torch_utils import apply_freeu +from ..activations import get_activation +from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from ..normalization import AdaGroupNorm +from ..resnet import ( + Downsample2D, + FirDownsample2D, + FirUpsample2D, + KDownsample2D, + KUpsample2D, + ResnetBlock2D, + ResnetBlockCondNorm2D, + Upsample2D, +) +from ..transformers.dual_transformer_2d import DualTransformer2DModel +from ..transformers.transformer_2d import Transformer2DModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_mid_block( + mid_block_type: str, + temb_channels: int, + in_channels: int, + resnet_eps: float, + resnet_act_fn: str, + resnet_groups: int, + output_scale_factor: float = 1.0, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + mid_block_only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = 1, + dropout: float = 0.0, +): + if mid_block_type == "UNetMidBlock2DCrossAttn": + return UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + return UNetMidBlock2DSimpleCrossAttn( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlock2D": + return UNetMidBlock2D( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + num_layers=0, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + return None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +) -> nn.Module: + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class AutoencoderTinyBlock(nn.Module): + """ + Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU + blocks. + + Args: + in_channels (`int`): The number of input channels. + out_channels (`int`): The number of output channels. + act_fn (`str`): + ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`. + + Returns: + `torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to + `out_channels`. + """ + + def __init__(self, in_channels: int, out_channels: int, act_fn: str): + super().__init__() + act_fn = get_activation(act_fn) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + ) + self.skip = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + if in_channels != out_channels + else nn.Identity() + ) + self.fuse = nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fuse(self.conv(x) + self.skip(x)) + + +class UNetMidBlock2D(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, + height, width)`. + + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + resnet_groups_out = resnet_groups_out or resnet_groups + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + groups_out=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups_out, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class AttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + downsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + self.downsample_type = downsample_type + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if downsample_type == "conv": + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + elif downsample_type == "resnet": + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + if self.downsample_type == "resnet": + hidden_states = downsampler(hidden_states, temb=temb) + else: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + additional_residuals: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + skip_sample: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + skip_sample: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class ResnetDownsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + + self.has_cross_attention = True + + resnets = [] + attentions = [] + + self.attention_head_dim = attention_head_dim + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class KDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + add_downsample: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + # YiYi's comments- might be able to use FirDownsample2D, look into details later + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class KCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + cross_attention_dim: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_group_size: int = 32, + add_downsample: bool = True, + attention_head_dim: int = 64, + add_self_attention: bool = False, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + out_channels, + out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + group_size=resnet_group_size, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.downsamplers is None: + output_states += (None,) + else: + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class AttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + upsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + + self.upsample_type = upsample_type + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if upsample_type == "conv": + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + elif upsample_type == "resnet": + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + if self.upsample_type == "resnet": + hidden_states = upsampler(hidden_states, temb=temb) + else: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + skip_sample=None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + hidden_states = self.attentions[0](hidden_states) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + upsample_padding: int = 1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + skip_sample=None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + # resnet + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class KUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + resolution_idx: int, + dropout: float = 0.0, + num_layers: int = 5, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: Optional[int] = 32, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=k_out_channels if (i == num_layers - 1) else out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class KCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + resolution_idx: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + attention_head_dim: int = 1, # attention dim_head + cross_attention_dim: int = 768, + add_upsample: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + # in_channels, and out_channels for the block (k-unet) + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + if is_middle_block and (i == num_layers - 1): + conv_2d_out_channels = k_out_channels + else: + conv_2d_out_channels = None + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + k_out_channels if (i == num_layers - 1) else out_channels, + k_out_channels // attention_head_dim + if (i == num_layers - 1) + else out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + upcast_attention=upcast_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + attention_bias (`bool`, *optional*, defaults to `False`): + Configure if the attention layers should contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to upcast the attention computation to `float32`. + temb_channels (`int`, *optional*, defaults to 768): + The number of channels in the token embedding. + add_self_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to add self-attention to the block. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + group_size (`int`, *optional*, defaults to 32): + The number of groups to separate the channels into for group normalization. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + upcast_attention: bool = False, + temb_channels: int = 768, # for ada_group_norm + add_self_attention: bool = False, + cross_attention_norm: Optional[str] = None, + group_size: int = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + + # 1. Self-Attn + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + + # 2. Cross-Attn + self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attention_norm=cross_attention_norm, + ) + + def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor: + return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) + + def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor: + return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + # TODO: mark emb as non-optional (self.norm2 requires it). + # requires assessing impact of change to positional param interface. + emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # 1. Self-Attention + if self.add_self_attention: + norm_hidden_states = self.norm1(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention/None + norm_hidden_states = self.norm2(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + return hidden_states diff --git a/diffusers/src/diffusers/models/unets/unet_2d_blocks_flax.py b/diffusers/src/diffusers/models/unets/unet_2d_blocks_flax.py new file mode 100755 index 0000000..a4585db --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_2d_blocks_flax.py @@ -0,0 +1,400 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import flax.linen as nn +import jax.numpy as jnp + +from ..attention_flax import FlaxTransformer2DModel +from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D + + +class FlaxCrossAttnDownBlock2D(nn.Module): + r""" + Cross Attention 2D Downsizing block - original architecture from Unet transformers: + https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + num_attention_heads (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + split_head_dim (`bool`, *optional*, defaults to `False`): + Whether to split the head dimension into a new axis for the self-attention computation. In most cases, + enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + num_attention_heads: int = 1 + add_downsample: bool = True + use_linear_projection: bool = False + only_cross_attention: bool = False + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + dtype: jnp.dtype = jnp.float32 + transformer_layers_per_block: int = 1 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxTransformer2DModel( + in_channels=self.out_channels, + n_heads=self.num_attention_heads, + d_head=self.out_channels // self.num_attention_heads, + depth=self.transformer_layers_per_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=self.only_cross_attention, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_downsample: + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsamplers_0(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxDownBlock2D(nn.Module): + r""" + Flax 2D downsizing block + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + self.resnets = resnets + + if self.add_downsample: + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, deterministic=True): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsamplers_0(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxCrossAttnUpBlock2D(nn.Module): + r""" + Cross Attention 2D Upsampling block - original architecture from Unet transformers: + https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + num_attention_heads (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_upsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add upsampling layer before each final output + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + split_head_dim (`bool`, *optional*, defaults to `False`): + Whether to split the head dimension into a new axis for the self-attention computation. In most cases, + enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + num_attention_heads: int = 1 + add_upsample: bool = True + use_linear_projection: bool = False + only_cross_attention: bool = False + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + dtype: jnp.dtype = jnp.float32 + transformer_layers_per_block: int = 1 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxTransformer2DModel( + in_channels=self.out_channels, + n_heads=self.num_attention_heads, + d_head=self.out_channels // self.num_attention_heads, + depth=self.transformer_layers_per_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=self.only_cross_attention, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_upsample: + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUpBlock2D(nn.Module): + r""" + Flax 2D upsampling block + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + prev_output_channel (:obj:`int`): + Output channels from the previous block + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + + if self.add_upsample: + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUNetMidBlock2DCrossAttn(nn.Module): + r""" + Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + num_attention_heads (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + split_head_dim (`bool`, *optional*, defaults to `False`): + Whether to split the head dimension into a new axis for the self-attention computation. In most cases, + enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + dropout: float = 0.0 + num_layers: int = 1 + num_attention_heads: int = 1 + use_linear_projection: bool = False + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + dtype: jnp.dtype = jnp.float32 + transformer_layers_per_block: int = 1 + + def setup(self): + # there is always at least one resnet + resnets = [ + FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + ] + + attentions = [] + + for _ in range(self.num_layers): + attn_block = FlaxTransformer2DModel( + in_channels=self.in_channels, + n_heads=self.num_attention_heads, + d_head=self.in_channels // self.num_attention_heads, + depth=self.transformer_layers_per_block, + use_linear_projection=self.use_linear_projection, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + attentions.append(attn_block) + + res_block = FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + self.attentions = attentions + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + + return hidden_states diff --git a/diffusers/src/diffusers/models/unets/unet_2d_condition.py b/diffusers/src/diffusers/models/unets/unet_2d_condition.py new file mode 100755 index 0000000..4f55df3 --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_2d_condition.py @@ -0,0 +1,1312 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ..activations import get_activation +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from ..embeddings import ( + GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from ..modeling_utils import ModelMixin +from .unet_2d_blocks import ( + get_down_block, + get_mid_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor = None + + +class UNet2DConditionModel( + ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin +): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + self._check_config( + down_block_types=down_block_types, + up_block_types=up_block_types, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim, timestep_input_dim = self._set_time_proj( + time_embedding_type, + block_out_channels=block_out_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + time_embedding_dim=time_embedding_dim, + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + self._set_encoder_hid_proj( + encoder_hid_dim_type, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + ) + + # class embedding + self._set_class_embedding( + class_embed_type, + act_fn=act_fn, + num_class_embeds=num_class_embeds, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + timestep_input_dim=timestep_input_dim, + ) + + self._set_add_embedding( + addition_embed_type, + addition_embed_type_num_heads=addition_embed_type_num_heads, + addition_time_embed_dim=addition_time_embed_dim, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + ) + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + temb_channels=blocks_time_embed_dim, + in_channels=block_out_channels[-1], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + output_scale_factor=mid_block_scale_factor, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + mid_block_only_cross_attention=mid_block_only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[-1], + dropout=dropout, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) + + def _check_config( + self, + down_block_types: Tuple[str], + up_block_types: Tuple[str], + only_cross_attention: Union[bool, Tuple[bool]], + block_out_channels: Tuple[int], + layers_per_block: Union[int, Tuple[int]], + cross_attention_dim: Union[int, Tuple[int]], + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], + reverse_transformer_layers_per_block: bool, + attention_head_dim: int, + num_attention_heads: Optional[Union[int, Tuple[int]]], + ): + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + def _set_time_proj( + self, + time_embedding_type: str, + block_out_channels: int, + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: int, + ) -> Tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + return time_embed_dim, timestep_input_dim + + def _set_encoder_hid_proj( + self, + encoder_hid_dim_type: Optional[str], + cross_attention_dim: Union[int, Tuple[int]], + encoder_hid_dim: Optional[int], + ): + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'." + ) + else: + self.encoder_hid_proj = None + + def _set_class_embedding( + self, + class_embed_type: Optional[str], + act_fn: str, + num_class_embeds: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + timestep_input_dim: int, + ): + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + def _set_add_embedding( + self, + addition_embed_type: str, + addition_embed_type_num_heads: int, + addition_time_embed_dim: Optional[int], + flip_sin_to_cos: bool, + freq_shift: float, + cross_attention_dim: Optional[int], + encoder_hid_dim: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + ): + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError( + f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." + ) + + def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, (list, tuple)): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def get_time_embed( + self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] + ) -> Optional[torch.Tensor]: + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + return t_emb + + def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + class_emb = None + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + return class_emb + + def get_aug_embed( + self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> Optional[torch.Tensor]: + aug_emb = None + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb = self.add_embedding(image_embs, hint) + return aug_emb + + def process_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> torch.Tensor: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: + encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) + + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + return encoder_hidden_states + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + t_emb = self.get_time_embed(sample=sample, timestep=timestep) + emb = self.time_embedding(t_emb, timestep_cond) + + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + aug_emb = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated + # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. + if cross_attention_kwargs is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + lora_scale = cross_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/diffusers/src/diffusers/models/unets/unet_2d_condition_flax.py b/diffusers/src/diffusers/models/unets/unet_2d_condition_flax.py new file mode 100755 index 0000000..edbbcba --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_2d_condition_flax.py @@ -0,0 +1,454 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Tuple, Union + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ...utils import BaseOutput +from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from ..modeling_flax_utils import FlaxModelMixin +from .unet_2d_blocks_flax import ( + FlaxCrossAttnDownBlock2D, + FlaxCrossAttnUpBlock2D, + FlaxDownBlock2D, + FlaxUNetMidBlock2DCrossAttn, + FlaxUpBlock2D, +) + + +@flax.struct.dataclass +class FlaxUNet2DConditionOutput(BaseOutput): + """ + The output of [`FlaxUNet2DConditionModel`]. + + Args: + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: jnp.ndarray + + +@flax_register_to_config +class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods + implemented for all models (such as downloading or saving). + + This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its + general usage and behavior. + + Inherent JAX features such as the following are supported: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + sample_size (`int`, *optional*): + The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): + The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): + The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer + is skipped. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): + The dimension of the attention heads. + num_attention_heads (`int` or `Tuple[int]`, *optional*): + The number of attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): + The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): + Dropout probability for down, up and bottleneck blocks. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682). + split_head_dim (`bool`, *optional*, defaults to `False`): + Whether to split the head dimension into a new axis for the self-attention computation. In most cases, + enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. + """ + + sample_size: int = 32 + in_channels: int = 4 + out_channels: int = 4 + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ) + up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn" + only_cross_attention: Union[bool, Tuple[bool]] = False + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280) + layers_per_block: int = 2 + attention_head_dim: Union[int, Tuple[int, ...]] = 8 + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None + cross_attention_dim: int = 1280 + dropout: float = 0.0 + use_linear_projection: bool = False + dtype: jnp.dtype = jnp.float32 + flip_sin_to_cos: bool = True + freq_shift: int = 0 + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1 + addition_embed_type: Optional[str] = None + addition_time_embed_dim: Optional[int] = None + addition_embed_type_num_heads: int = 64 + projection_class_embeddings_input_dim: Optional[int] = None + + def init_weights(self, rng: jax.Array) -> FrozenDict: + # init input tensors + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + timesteps = jnp.ones((1,), dtype=jnp.int32) + encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + added_cond_kwargs = None + if self.addition_embed_type == "text_time": + # we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner + # or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim` + is_refiner = ( + 5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim + == self.config.projection_class_embeddings_input_dim + ) + num_micro_conditions = 5 if is_refiner else 6 + + text_embeds_dim = self.config.projection_class_embeddings_input_dim - ( + num_micro_conditions * self.config.addition_time_embed_dim + ) + + time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim + time_ids_dims = time_ids_channels // self.addition_time_embed_dim + added_cond_kwargs = { + "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32), + "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32), + } + return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"] + + def setup(self) -> None: + block_out_channels = self.block_out_channels + time_embed_dim = block_out_channels[0] * 4 + + if self.num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = self.num_attention_heads or self.attention_head_dim + + # input + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # time + self.time_proj = FlaxTimesteps( + block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift + ) + self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + + only_cross_attention = self.only_cross_attention + if isinstance(only_cross_attention, bool): + only_cross_attention = (only_cross_attention,) * len(self.down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(self.down_block_types) + + # transformer layers per block + transformer_layers_per_block = self.transformer_layers_per_block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types) + + # addition embed types + if self.addition_embed_type is None: + self.add_embedding = None + elif self.addition_embed_type == "text_time": + if self.addition_time_embed_dim is None: + raise ValueError( + f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None" + ) + self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift) + self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + else: + raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.") + + # down + down_blocks = [] + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(self.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlock2D": + down_block = FlaxCrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + num_attention_heads=num_attention_heads[i], + add_downsample=not is_final_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + else: + down_block = FlaxDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + + down_blocks.append(down_block) + self.down_blocks = down_blocks + + # mid + if self.config.mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = FlaxUNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + dropout=self.dropout, + num_attention_heads=num_attention_heads[-1], + transformer_layers_per_block=transformer_layers_per_block[-1], + use_linear_projection=self.use_linear_projection, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + elif self.config.mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}") + + # up + up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + for i, up_block_type in enumerate(self.up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "CrossAttnUpBlock2D": + up_block = FlaxCrossAttnUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=self.layers_per_block + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + num_attention_heads=reversed_num_attention_heads[i], + add_upsample=not is_final_block, + dropout=self.dropout, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + else: + up_block = FlaxUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=self.layers_per_block + 1, + add_upsample=not is_final_block, + dropout=self.dropout, + dtype=self.dtype, + ) + + up_blocks.append(up_block) + prev_output_channel = output_channel + self.up_blocks = up_blocks + + # out + self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv_out = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__( + self, + sample: jnp.ndarray, + timesteps: Union[jnp.ndarray, float, int], + encoder_hidden_states: jnp.ndarray, + added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None, + down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None, + mid_block_additional_residual: Optional[jnp.ndarray] = None, + return_dict: bool = True, + train: bool = False, + ) -> Union[FlaxUNet2DConditionOutput, Tuple[jnp.ndarray]]: + r""" + Args: + sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor + timestep (`jnp.ndarray` or `float` or `int`): timesteps + encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of + a plain tuple. + train (`bool`, *optional*, defaults to `False`): + Use deterministic functions and disable dropout when not training. + + Returns: + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + """ + # 1. time + if not isinstance(timesteps, jnp.ndarray): + timesteps = jnp.array([timesteps], dtype=jnp.int32) + elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: + timesteps = timesteps.astype(dtype=jnp.float32) + timesteps = jnp.expand_dims(timesteps, 0) + + t_emb = self.time_proj(timesteps) + t_emb = self.time_embedding(t_emb) + + # additional embeddings + aug_emb = None + if self.addition_embed_type == "text_time": + if added_cond_kwargs is None: + raise ValueError( + f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if text_embeds is None: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + if time_ids is None: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + # compute time embeds + time_embeds = self.add_time_proj(jnp.ravel(time_ids)) # (1, 6) => (6,) => (6, 256) + time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1)) + add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1) + aug_emb = self.add_embedding(add_embeds) + + t_emb = t_emb + aug_emb if aug_emb is not None else t_emb + + # 2. pre-process + sample = jnp.transpose(sample, (0, 2, 3, 1)) + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + if isinstance(down_block, FlaxCrossAttnDownBlock2D): + sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + else: + sample, res_samples = down_block(sample, t_emb, deterministic=not train) + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample += down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + + # 5. up + for up_block in self.up_blocks: + res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] + down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)] + if isinstance(up_block, FlaxCrossAttnUpBlock2D): + sample = up_block( + sample, + temb=t_emb, + encoder_hidden_states=encoder_hidden_states, + res_hidden_states_tuple=res_samples, + deterministic=not train, + ) + else: + sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = nn.silu(sample) + sample = self.conv_out(sample) + sample = jnp.transpose(sample, (0, 3, 1, 2)) + + if not return_dict: + return (sample,) + + return FlaxUNet2DConditionOutput(sample=sample) diff --git a/diffusers/src/diffusers/models/unets/unet_3d_blocks.py b/diffusers/src/diffusers/models/unets/unet_3d_blocks.py new file mode 100755 index 0000000..8b472a8 --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_3d_blocks.py @@ -0,0 +1,1538 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn + +from ...utils import deprecate, is_torch_version, logging +from ...utils.torch_utils import apply_freeu +from ..attention import Attention +from ..resnet import ( + Downsample2D, + ResnetBlock2D, + SpatioTemporalResBlock, + TemporalConvLayer, + Upsample2D, +) +from ..transformers.transformer_2d import Transformer2DModel +from ..transformers.transformer_temporal import ( + TransformerSpatioTemporalModel, + TransformerTemporalModel, +) +from .unet_motion_model import ( + CrossAttnDownBlockMotion, + CrossAttnUpBlockMotion, + DownBlockMotion, + UNetMidBlockCrossAttnMotion, + UpBlockMotion, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DownBlockMotion(DownBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead." + deprecate("DownBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead." + deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class UpBlockMotion(UpBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead." + deprecate("UpBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead." + deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead." + deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + dropout: float = 0.0, +) -> Union[ + "DownBlock3D", + "CrossAttnDownBlock3D", + "DownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", +]: + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + dropout=dropout, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + dropout=dropout, + ) + elif down_block_type == "DownBlockSpatioTemporal": + # added for SDV + return DownBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "CrossAttnDownBlockSpatioTemporal": + # added for SDV + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal") + return CrossAttnDownBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_downsample=add_downsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + ) + + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resolution_idx: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + dropout: float = 0.0, +) -> Union[ + "UpBlock3D", + "CrossAttnUpBlock3D", + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", +]: + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + dropout=dropout, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + dropout=dropout, + ) + elif up_block_type == "UpBlockSpatioTemporal": + # added for SDV + return UpBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + add_upsample=add_upsample, + ) + elif up_block_type == "CrossAttnUpBlockSpatioTemporal": + # added for SDV + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal") + return CrossAttnUpBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_upsample=add_upsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resolution_idx=resolution_idx, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Dict[str, Any] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + num_frames: int = 1, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resolution_idx: Optional[int] = None, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Dict[str, Any] = None, + ) -> torch.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + # TODO(Patrick, William) - attention mask is not used + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + resolution_idx: Optional[int] = None, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, + num_frames: int = 1, + ) -> torch.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class MidBlockTemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + attention_head_dim: int = 512, + num_layers: int = 1, + upcast_attention: bool = False, + ): + super().__init__() + + resnets = [] + attentions = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + + attentions.append( + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=1e-6, + upcast_attention=upcast_attention, + norm_num_groups=32, + bias=True, + residual_connection=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.Tensor, + image_only_indicator: torch.Tensor, + ): + hidden_states = self.resnets[0]( + hidden_states, + image_only_indicator=image_only_indicator, + ) + for resnet, attn in zip(self.resnets[1:], self.attentions): + hidden_states = attn(hidden_states) + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + return hidden_states + + +class UpBlockTemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward( + self, + hidden_states: torch.Tensor, + image_only_indicator: torch.Tensor, + ) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UNetMidBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + # there is always at least one resnet + resnets = [ + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ] + attentions = [] + + for i in range(num_layers): + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.resnets[0]( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + return hidden_states + + +class DownBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + ) + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-6, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=1, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + for resnet, attn in blocks: + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UpBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + num_layers: int = 1, + resnet_eps: float = 1e-6, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + ) + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states diff --git a/diffusers/src/diffusers/models/unets/unet_3d_condition.py b/diffusers/src/diffusers/models/unets/unet_3d_condition.py new file mode 100755 index 0000000..3081fdc --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_3d_condition.py @@ -0,0 +1,748 @@ +# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2024 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import BaseOutput, logging +from ..activations import get_activation +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..transformers.transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + The output of [`UNet3DConditionModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): The number of attention heads. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str, ...] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 64, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + time_cond_proj_dim: Optional[int] = None, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise NotImplementedError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + cond_proj_dim=time_cond_proj_dim, + ) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + norm_num_groups=norm_num_groups, + ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=False, + resolution_idx=i, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = get_activation("silu") + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]: + r""" + The [`UNet3DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + + Returns: + [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + sample = self.transformer_in( + sample, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/diffusers/src/diffusers/models/unets/unet_i2vgen_xl.py b/diffusers/src/diffusers/models/unets/unet_i2vgen_xl.py new file mode 100755 index 0000000..6ab3a57 --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -0,0 +1,728 @@ +# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import logging +from ..activations import get_activation +from ..attention import Attention, FeedForward +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..transformers.transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) +from .unet_3d_condition import UNet3DConditionOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class I2VGenXLTransformerTemporalEncoder(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "geglu", + upcast_attention: bool = False, + ff_inner_dim: Optional[int] = None, + dropout: int = 0.0, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=True, + ) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=False, + inner_dim=ff_inner_dim, + bias=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + ff_output = self.ff(hidden_states) + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and + returns a sample-shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 64): Attention head dim. + num_attention_heads (`int`, *optional*): The number of attention heads. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str, ...] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + norm_num_groups: Optional[int] = 32, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 64, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + ): + super().__init__() + + # When we first integrated the UNet into the library, we didn't have `attention_head_dim`. As a consequence + # of that, we used `num_attention_heads` for arguments that actually denote attention head dimension. This + # is why we ignore `num_attention_heads` and calculate it from `attention_head_dims` below. + # This is still an incorrect way of calculating `num_attention_heads` but we need to stick to it + # without running proper depcrecation cycles for the {down,mid,up} blocks which are a + # part of the public API. + num_attention_heads = attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d(in_channels + in_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=num_attention_heads, + in_channels=block_out_channels[0], + num_layers=1, + norm_num_groups=norm_num_groups, + ) + + # image embedding + self.image_latents_proj_in = nn.Sequential( + nn.Conv2d(4, in_channels * 4, 3, padding=1), + nn.SiLU(), + nn.Conv2d(in_channels * 4, in_channels * 4, 3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(in_channels * 4, in_channels, 3, stride=1, padding=1), + ) + self.image_latents_temporal_encoder = I2VGenXLTransformerTemporalEncoder( + dim=in_channels, + num_attention_heads=2, + ff_inner_dim=in_channels * 4, + attention_head_dim=in_channels, + activation_fn="gelu", + ) + self.image_latents_context_embedding = nn.Sequential( + nn.Conv2d(4, in_channels * 8, 3, padding=1), + nn.SiLU(), + nn.AdaptiveAvgPool2d((32, 32)), + nn.Conv2d(in_channels * 8, in_channels * 16, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(in_channels * 16, cross_attention_dim, 3, stride=2, padding=1), + ) + + # other embeddings -- time, context, fps, etc. + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn="silu") + self.context_embedding = nn.Sequential( + nn.Linear(cross_attention_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, cross_attention_dim * in_channels), + ) + self.fps_embedding = nn.Sequential( + nn.Linear(timestep_input_dim, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # blocks + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=1e-05, + resnet_act_fn="silu", + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + downsample_padding=1, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=1e-05, + resnet_act_fn="silu", + output_scale_factor=1, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=1e-05, + resnet_act_fn="silu", + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=False, + resolution_idx=i, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-05) + self.conv_act = get_activation("silu") + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + fps: torch.Tensor, + image_latents: torch.Tensor, + image_embeddings: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]: + r""" + The [`I2VGenXLUNet`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + fps (`torch.Tensor`): Frames per second for the video being generated. Used as a "micro-condition". + image_latents (`torch.Tensor`): Image encodings from the VAE. + image_embeddings (`torch.Tensor`): + Projection embeddings of the conditioning image computed with a vision encoder. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + batch_size, channels, num_frames, height, width = sample.shape + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + t_emb = self.time_embedding(t_emb, timestep_cond) + + # 2. FPS + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + fps = fps.expand(fps.shape[0]) + fps_emb = self.fps_embedding(self.time_proj(fps).to(dtype=self.dtype)) + + # 3. time + FPS embeddings. + emb = t_emb + fps_emb + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + + # 4. context embeddings. + # The context embeddings consist of both text embeddings from the input prompt + # AND the image embeddings from the input image. For images, both VAE encodings + # and the CLIP image embeddings are incorporated. + # So the final `context_embeddings` becomes the query for cross-attention. + context_emb = sample.new_zeros(batch_size, 0, self.config.cross_attention_dim) + context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1) + + image_latents_for_context_embds = image_latents[:, :, :1, :] + image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape( + image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2], + image_latents_for_context_embds.shape[1], + image_latents_for_context_embds.shape[3], + image_latents_for_context_embds.shape[4], + ) + image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs) + + _batch_size, _channels, _height, _width = image_latents_context_embs.shape + image_latents_context_embs = image_latents_context_embs.permute(0, 2, 3, 1).reshape( + _batch_size, _height * _width, _channels + ) + context_emb = torch.cat([context_emb, image_latents_context_embs], dim=1) + + image_emb = self.context_embedding(image_embeddings) + image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim) + context_emb = torch.cat([context_emb, image_emb], dim=1) + context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0) + + image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape( + image_latents.shape[0] * image_latents.shape[2], + image_latents.shape[1], + image_latents.shape[3], + image_latents.shape[4], + ) + image_latents = self.image_latents_proj_in(image_latents) + image_latents = ( + image_latents[None, :] + .reshape(batch_size, num_frames, channels, height, width) + .permute(0, 3, 4, 1, 2) + .reshape(batch_size * height * width, num_frames, channels) + ) + image_latents = self.image_latents_temporal_encoder(image_latents) + image_latents = image_latents.reshape(batch_size, height, width, num_frames, channels).permute(0, 4, 3, 1, 2) + + # 5. pre-process + sample = torch.cat([sample, image_latents], dim=1) + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + sample = self.transformer_in( + sample, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # 6. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=context_emb, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + # 7. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=context_emb, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + # 8. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=context_emb, + upsample_size=upsample_size, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 9. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/diffusers/src/diffusers/models/unets/unet_kandinsky3.py b/diffusers/src/diffusers/models/unets/unet_kandinsky3.py new file mode 100755 index 0000000..f611e7d --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_kandinsky3.py @@ -0,0 +1,535 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, logging +from ..attention_processor import Attention, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class Kandinsky3UNetOutput(BaseOutput): + sample: torch.Tensor = None + + +class Kandinsky3EncoderProj(nn.Module): + def __init__(self, encoder_hid_dim, cross_attention_dim): + super().__init__() + self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False) + self.projection_norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, x): + x = self.projection_linear(x) + x = self.projection_norm(x) + return x + + +class Kandinsky3UNet(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 4, + time_embedding_dim: int = 1536, + groups: int = 32, + attention_head_dim: int = 64, + layers_per_block: Union[int, Tuple[int]] = 3, + block_out_channels: Tuple[int] = (384, 768, 1536, 3072), + cross_attention_dim: Union[int, Tuple[int]] = 4096, + encoder_hid_dim: int = 4096, + ): + super().__init__() + + # TODO(Yiyi): Give better name and put into config for the following 4 parameters + expansion_ratio = 4 + compression_ratio = 2 + add_cross_attention = (False, True, True, True) + add_self_attention = (False, True, True, True) + + out_channels = in_channels + init_channels = block_out_channels[0] // 2 + self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1) + + self.time_embedding = TimestepEmbedding( + init_channels, + time_embedding_dim, + ) + + self.add_time_condition = Kandinsky3AttentionPooling( + time_embedding_dim, cross_attention_dim, attention_head_dim + ) + + self.conv_in = nn.Conv2d(in_channels, init_channels, kernel_size=3, padding=1) + + self.encoder_hid_proj = Kandinsky3EncoderProj(encoder_hid_dim, cross_attention_dim) + + hidden_dims = [init_channels] + list(block_out_channels) + in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:])) + text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention] + num_blocks = len(block_out_channels) * [layers_per_block] + layer_params = [num_blocks, text_dims, add_self_attention] + rev_layer_params = map(reversed, layer_params) + + cat_dims = [] + self.num_levels = len(in_out_dims) + self.down_blocks = nn.ModuleList([]) + for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate( + zip(in_out_dims, *layer_params) + ): + down_sample = level != (self.num_levels - 1) + cat_dims.append(out_dim if level != (self.num_levels - 1) else 0) + self.down_blocks.append( + Kandinsky3DownSampleBlock( + in_dim, + out_dim, + time_embedding_dim, + text_dim, + res_block_num, + groups, + attention_head_dim, + expansion_ratio, + compression_ratio, + down_sample, + self_attention, + ) + ) + + self.up_blocks = nn.ModuleList([]) + for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate( + zip(reversed(in_out_dims), *rev_layer_params) + ): + up_sample = level != 0 + self.up_blocks.append( + Kandinsky3UpSampleBlock( + in_dim, + cat_dims.pop(), + out_dim, + time_embedding_dim, + text_dim, + res_block_num, + groups, + attention_head_dim, + expansion_ratio, + compression_ratio, + up_sample, + self_attention, + ) + ) + + self.conv_norm_out = nn.GroupNorm(groups, init_channels) + self.conv_act_out = nn.SiLU() + self.conv_out = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True): + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if not torch.is_tensor(timestep): + dtype = torch.float32 if isinstance(timestep, float) else torch.int32 + timestep = torch.tensor([timestep], dtype=dtype, device=sample.device) + elif len(timestep.shape) == 0: + timestep = timestep[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = timestep.expand(sample.shape[0]) + time_embed_input = self.time_proj(timestep).to(sample.dtype) + time_embed = self.time_embedding(time_embed_input) + + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + + if encoder_hidden_states is not None: + time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask) + + hidden_states = [] + sample = self.conv_in(sample) + for level, down_sample in enumerate(self.down_blocks): + sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) + if level != self.num_levels - 1: + hidden_states.append(sample) + + for level, up_sample in enumerate(self.up_blocks): + if level != 0: + sample = torch.cat([sample, hidden_states.pop()], dim=1) + sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) + + sample = self.conv_norm_out(sample) + sample = self.conv_act_out(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + return Kandinsky3UNetOutput(sample=sample) + + +class Kandinsky3UpSampleBlock(nn.Module): + def __init__( + self, + in_channels, + cat_dim, + out_channels, + time_embed_dim, + context_dim=None, + num_blocks=3, + groups=32, + head_dim=64, + expansion_ratio=4, + compression_ratio=2, + up_sample=True, + self_attention=True, + ): + super().__init__() + up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1) + hidden_channels = ( + [(in_channels + cat_dim, in_channels)] + + [(in_channels, in_channels)] * (num_blocks - 2) + + [(in_channels, out_channels)] + ) + attentions = [] + resnets_in = [] + resnets_out = [] + + self.self_attention = self_attention + self.context_dim = context_dim + + if self_attention: + attentions.append( + Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio) + ) + else: + attentions.append(nn.Identity()) + + for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): + resnets_in.append( + Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution) + ) + + if context_dim is not None: + attentions.append( + Kandinsky3AttentionBlock( + in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio + ) + ) + else: + attentions.append(nn.Identity()) + + resnets_out.append( + Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets_in = nn.ModuleList(resnets_in) + self.resnets_out = nn.ModuleList(resnets_out) + + def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): + for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out): + x = resnet_in(x, time_embed) + if self.context_dim is not None: + x = attention(x, time_embed, context, context_mask, image_mask) + x = resnet_out(x, time_embed) + + if self.self_attention: + x = self.attentions[0](x, time_embed, image_mask=image_mask) + return x + + +class Kandinsky3DownSampleBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + time_embed_dim, + context_dim=None, + num_blocks=3, + groups=32, + head_dim=64, + expansion_ratio=4, + compression_ratio=2, + down_sample=True, + self_attention=True, + ): + super().__init__() + attentions = [] + resnets_in = [] + resnets_out = [] + + self.self_attention = self_attention + self.context_dim = context_dim + + if self_attention: + attentions.append( + Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio) + ) + else: + attentions.append(nn.Identity()) + + up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]] + hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1) + for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): + resnets_in.append( + Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) + ) + + if context_dim is not None: + attentions.append( + Kandinsky3AttentionBlock( + out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio + ) + ) + else: + attentions.append(nn.Identity()) + + resnets_out.append( + Kandinsky3ResNetBlock( + out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets_in = nn.ModuleList(resnets_in) + self.resnets_out = nn.ModuleList(resnets_out) + + def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): + if self.self_attention: + x = self.attentions[0](x, time_embed, image_mask=image_mask) + + for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out): + x = resnet_in(x, time_embed) + if self.context_dim is not None: + x = attention(x, time_embed, context, context_mask, image_mask) + x = resnet_out(x, time_embed) + return x + + +class Kandinsky3ConditionalGroupNorm(nn.Module): + def __init__(self, groups, normalized_shape, context_dim): + super().__init__() + self.norm = nn.GroupNorm(groups, normalized_shape, affine=False) + self.context_mlp = nn.Sequential(nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape)) + self.context_mlp[1].weight.data.zero_() + self.context_mlp[1].bias.data.zero_() + + def forward(self, x, context): + context = self.context_mlp(context) + + for _ in range(len(x.shape[2:])): + context = context.unsqueeze(-1) + + scale, shift = context.chunk(2, dim=1) + x = self.norm(x) * (scale + 1.0) + shift + return x + + +class Kandinsky3Block(nn.Module): + def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None): + super().__init__() + self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim) + self.activation = nn.SiLU() + if up_resolution is not None and up_resolution: + self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2) + else: + self.up_sample = nn.Identity() + + padding = int(kernel_size > 1) + self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding) + + if up_resolution is not None and not up_resolution: + self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2) + else: + self.down_sample = nn.Identity() + + def forward(self, x, time_embed): + x = self.group_norm(x, time_embed) + x = self.activation(x) + x = self.up_sample(x) + x = self.projection(x) + x = self.down_sample(x) + return x + + +class Kandinsky3ResNetBlock(nn.Module): + def __init__( + self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4 * [None] + ): + super().__init__() + kernel_sizes = [1, 3, 3, 1] + hidden_channel = max(in_channels, out_channels) // compression_ratio + hidden_channels = ( + [(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)] + ) + self.resnet_blocks = nn.ModuleList( + [ + Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution) + for (in_channel, out_channel), kernel_size, up_resolution in zip( + hidden_channels, kernel_sizes, up_resolutions + ) + ] + ) + self.shortcut_up_sample = ( + nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2) + if True in up_resolutions + else nn.Identity() + ) + self.shortcut_projection = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() + ) + self.shortcut_down_sample = ( + nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2) + if False in up_resolutions + else nn.Identity() + ) + + def forward(self, x, time_embed): + out = x + for resnet_block in self.resnet_blocks: + out = resnet_block(out, time_embed) + + x = self.shortcut_up_sample(x) + x = self.shortcut_projection(x) + x = self.shortcut_down_sample(x) + x = x + out + return x + + +class Kandinsky3AttentionPooling(nn.Module): + def __init__(self, num_channels, context_dim, head_dim=64): + super().__init__() + self.attention = Attention( + context_dim, + context_dim, + dim_head=head_dim, + out_dim=num_channels, + out_bias=False, + ) + + def forward(self, x, context, context_mask=None): + context_mask = context_mask.to(dtype=context.dtype) + context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask) + return x + context.squeeze(1) + + +class Kandinsky3AttentionBlock(nn.Module): + def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4): + super().__init__() + self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) + self.attention = Attention( + num_channels, + context_dim or num_channels, + dim_head=head_dim, + out_dim=num_channels, + out_bias=False, + ) + + hidden_channels = expansion_ratio * num_channels + self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) + self.feed_forward = nn.Sequential( + nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False), + nn.SiLU(), + nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False), + ) + + def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): + height, width = x.shape[-2:] + out = self.in_norm(x, time_embed) + out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1) + context = context if context is not None else out + if context_mask is not None: + context_mask = context_mask.to(dtype=context.dtype) + + out = self.attention(out, context, context_mask) + out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width) + x = x + out + + out = self.out_norm(x, time_embed) + out = self.feed_forward(out) + x = x + out + return x diff --git a/diffusers/src/diffusers/models/unets/unet_motion_model.py b/diffusers/src/diffusers/models/unets/unet_motion_model.py new file mode 100755 index 0000000..6125feb --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_motion_model.py @@ -0,0 +1,2272 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...utils import BaseOutput, deprecate, is_torch_version, logging +from ...utils.torch_utils import apply_freeu +from ..attention import BasicTransformerBlock +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, + AttnProcessor2_0, + FusedAttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..transformers.dual_transformer_2d import DualTransformer2DModel +from ..transformers.transformer_2d import Transformer2DModel +from .unet_2d_blocks import UNetMidBlock2DCrossAttn +from .unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetMotionOutput(BaseOutput): + """ + The output of [`UNetMotionOutput`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor + + +class AnimateDiffTransformer3D(nn.Module): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.LongTensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + """ + The [`AnimateDiffTransformer3D`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Returns: + torch.Tensor: + The output tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(input=hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(input=hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + return output + + +class DownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + temporal_num_attention_heads: Union[int, Tuple[int]] = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + temporal_double_self_attention: bool = True, + ): + super().__init__() + resnets = [] + motion_modules = [] + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}" + ) + + # support for variable number of attention head per temporal layers + if isinstance(temporal_num_attention_heads, int): + temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers + elif len(temporal_num_attention_heads) != num_layers: + raise ValueError( + f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}" + ) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads[i], + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads[i], + double_self_attention=temporal_double_self_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + num_frames: int = 1, + *args, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + blocks = zip(self.resnets, self.motion_modules) + for resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + + else: + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states=hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + temporal_double_self_attention: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + double_self_attention=temporal_double_self_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + encoder_attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + additional_residuals: Optional[torch.Tensor] = None, + ): + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) + for i, (resnet, attn, motion_module) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + ) + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states=hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}" + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}" + ) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.attentions, self.motion_modules) + for resnet, attn, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) + + return hidden_states + + +class UpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + motion_modules = [] + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + upsample_size=None, + num_frames: int = 1, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.motion_modules) + + for resnet, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) + + return hidden_states + + +class UNetMidBlockCrossAttnMotion(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + motion_modules = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + attention_head_dim=in_channels // temporal_num_attention_heads, + in_channels=in_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + activation_fn="geglu", + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb) + + blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) + for attn, resnet, motion_module in blocks: + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + ) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + return hidden_states + + +class MotionModules(nn.Module): + def __init__( + self, + in_channels: int, + layers_per_block: int = 2, + transformer_layers_per_block: Union[int, Tuple[int]] = 8, + num_attention_heads: Union[int, Tuple[int]] = 8, + attention_bias: bool = False, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + norm_num_groups: int = 32, + max_seq_length: int = 32, + ): + super().__init__() + self.motion_modules = nn.ModuleList([]) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block + elif len(transformer_layers_per_block) != layers_per_block: + raise ValueError( + f"The number of transformer layers per block must match the number of layers per block, " + f"got {layers_per_block} and {len(transformer_layers_per_block)}" + ) + + for i in range(layers_per_block): + self.motion_modules.append( + AnimateDiffTransformer3D( + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads, + positional_embeddings="sinusoidal", + num_positional_embeddings=max_seq_length, + ) + ) + + +class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin): + @register_to_config + def __init__( + self, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + motion_layers_per_block: Union[int, Tuple[int]] = 2, + motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1, + motion_mid_block_layers_per_block: int = 1, + motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1, + motion_num_attention_heads: Union[int, Tuple[int]] = 8, + motion_norm_num_groups: int = 32, + motion_max_seq_length: int = 32, + use_motion_mid_block: bool = True, + conv_in_channels: Optional[int] = None, + ): + """Container to store AnimateDiff Motion Modules + + Args: + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each UNet block. + motion_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2): + The number of motion layers per UNet block. + motion_transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple[int]]`, *optional*, defaults to 1): + The number of transformer layers to use in each motion layer in each block. + motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1): + The number of motion layers in the middle UNet block. + motion_transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer layers to use in each motion layer in the middle block. + motion_num_attention_heads (`int` or `Tuple[int]`, *optional*, defaults to 8): + The number of heads to use in each attention layer of the motion module. + motion_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use in each group normalization layer of the motion module. + motion_max_seq_length (`int`, *optional*, defaults to 32): + The maximum sequence length to use in the motion module. + use_motion_mid_block (`bool`, *optional*, defaults to True): + Whether to use a motion module in the middle of the UNet. + """ + + super().__init__() + down_blocks = [] + up_blocks = [] + + if isinstance(motion_layers_per_block, int): + motion_layers_per_block = (motion_layers_per_block,) * len(block_out_channels) + elif len(motion_layers_per_block) != len(block_out_channels): + raise ValueError( + f"The number of motion layers per block must match the number of blocks, " + f"got {len(block_out_channels)} and {len(motion_layers_per_block)}" + ) + + if isinstance(motion_transformer_layers_per_block, int): + motion_transformer_layers_per_block = (motion_transformer_layers_per_block,) * len(block_out_channels) + + if isinstance(motion_transformer_layers_per_mid_block, int): + motion_transformer_layers_per_mid_block = ( + motion_transformer_layers_per_mid_block, + ) * motion_mid_block_layers_per_block + elif len(motion_transformer_layers_per_mid_block) != motion_mid_block_layers_per_block: + raise ValueError( + f"The number of layers per mid block ({motion_mid_block_layers_per_block}) " + f"must match the length of motion_transformer_layers_per_mid_block ({len(motion_transformer_layers_per_mid_block)})" + ) + + if isinstance(motion_num_attention_heads, int): + motion_num_attention_heads = (motion_num_attention_heads,) * len(block_out_channels) + elif len(motion_num_attention_heads) != len(block_out_channels): + raise ValueError( + f"The length of the attention head number tuple in the motion module must match the " + f"number of block, got {len(motion_num_attention_heads)} and {len(block_out_channels)}" + ) + + if conv_in_channels: + # input + self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1) + else: + self.conv_in = None + + for i, channel in enumerate(block_out_channels): + output_channel = block_out_channels[i] + down_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads[i], + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block[i], + transformer_layers_per_block=motion_transformer_layers_per_block[i], + ) + ) + + if use_motion_mid_block: + self.mid_block = MotionModules( + in_channels=block_out_channels[-1], + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads[-1], + max_seq_length=motion_max_seq_length, + layers_per_block=motion_mid_block_layers_per_block, + transformer_layers_per_block=motion_transformer_layers_per_mid_block, + ) + else: + self.mid_block = None + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + + reversed_motion_layers_per_block = list(reversed(motion_layers_per_block)) + reversed_motion_transformer_layers_per_block = list(reversed(motion_transformer_layers_per_block)) + reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads)) + for i, channel in enumerate(reversed_block_out_channels): + output_channel = reversed_block_out_channels[i] + up_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=reversed_motion_num_attention_heads[i], + max_seq_length=motion_max_seq_length, + layers_per_block=reversed_motion_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_motion_transformer_layers_per_block[i], + ) + ) + + self.down_blocks = nn.ModuleList(down_blocks) + self.up_blocks = nn.ModuleList(up_blocks) + + def forward(self, sample): + pass + + +class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + r""" + A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a + sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockMotion", + ), + up_block_types: Tuple[str, ...] = ( + "UpBlockMotion", + "CrossAttnUpBlockMotion", + "CrossAttnUpBlockMotion", + "CrossAttnUpBlockMotion", + ), + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None, + temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None, + transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None, + temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1, + use_linear_projection: bool = False, + num_attention_heads: Union[int, Tuple[int, ...]] = 8, + motion_max_seq_length: int = 32, + motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8, + reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None, + use_motion_mid_block: bool = True, + mid_block_layers: int = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + projection_class_embeddings_input_dim: Optional[int] = None, + time_cond_proj_dim: Optional[int] = None, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + if ( + isinstance(temporal_transformer_layers_per_block, list) + and reverse_temporal_transformer_layers_per_block is None + ): + for layer_number_per_block in temporal_transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError( + "Must provide 'reverse_temporal_transformer_layers_per_block` if using asymmetrical motion module in UNet." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim, act_fn=act_fn, cond_proj_dim=time_cond_proj_dim + ) + + if encoder_hid_dim_type is None: + self.encoder_hid_proj = None + + if addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if isinstance(reverse_transformer_layers_per_block, int): + reverse_transformer_layers_per_block = [reverse_transformer_layers_per_block] * len(down_block_types) + + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types) + + if isinstance(reverse_temporal_transformer_layers_per_block, int): + reverse_temporal_transformer_layers_per_block = [reverse_temporal_transformer_layers_per_block] * len( + down_block_types + ) + + if isinstance(motion_num_attention_heads, int): + motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlockMotion": + down_block = CrossAttnDownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + downsample_padding=downsample_padding, + add_downsample=not is_final_block, + use_linear_projection=use_linear_projection, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + ) + elif down_block_type == "DownBlockMotion": + down_block = DownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_layers=layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_downsample=not is_final_block, + downsample_padding=downsample_padding, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + ) + else: + raise ValueError( + "Invalid `down_block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`" + ) + + self.down_blocks.append(down_block) + + # mid + if transformer_layers_per_mid_block is None: + transformer_layers_per_mid_block = ( + transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1 + ) + + if use_motion_mid_block: + self.mid_block = UNetMidBlockCrossAttnMotion( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + num_layers=mid_block_layers, + temporal_num_attention_heads=motion_num_attention_heads[-1], + temporal_max_seq_length=motion_max_seq_length, + transformer_layers_per_block=transformer_layers_per_mid_block, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_mid_block, + ) + + else: + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + num_layers=mid_block_layers, + transformer_layers_per_block=transformer_layers_per_mid_block, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads)) + + if reverse_transformer_layers_per_block is None: + reverse_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + + if reverse_temporal_transformer_layers_per_block is None: + reverse_temporal_transformer_layers_per_block = list(reversed(temporal_transformer_layers_per_block)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + if up_block_type == "CrossAttnUpBlockMotion": + up_block = CrossAttnUpBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + resolution_idx=i, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reverse_transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + num_attention_heads=reversed_num_attention_heads[i], + cross_attention_dim=reversed_cross_attention_dim[i], + add_upsample=add_upsample, + use_linear_projection=use_linear_projection, + temporal_num_attention_heads=reversed_motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], + ) + elif up_block_type == "UpBlockMotion": + up_block = UpBlockMotion( + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + resolution_idx=i, + num_layers=reversed_layers_per_block[i] + 1, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_upsample=add_upsample, + temporal_num_attention_heads=reversed_motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], + ) + else: + raise ValueError( + "Invalid `up_block_type` encountered. Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`" + ) + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @classmethod + def from_unet2d( + cls, + unet: UNet2DConditionModel, + motion_adapter: Optional[MotionAdapter] = None, + load_weights: bool = True, + ): + has_motion_adapter = motion_adapter is not None + + if has_motion_adapter: + motion_adapter.to(device=unet.device) + + # check compatibility of number of blocks + if len(unet.config["down_block_types"]) != len(motion_adapter.config["block_out_channels"]): + raise ValueError("Incompatible Motion Adapter, got different number of blocks") + + # check layers compatibility for each block + if isinstance(unet.config["layers_per_block"], int): + expanded_layers_per_block = [unet.config["layers_per_block"]] * len(unet.config["down_block_types"]) + else: + expanded_layers_per_block = list(unet.config["layers_per_block"]) + if isinstance(motion_adapter.config["motion_layers_per_block"], int): + expanded_adapter_layers_per_block = [motion_adapter.config["motion_layers_per_block"]] * len( + motion_adapter.config["block_out_channels"] + ) + else: + expanded_adapter_layers_per_block = list(motion_adapter.config["motion_layers_per_block"]) + if expanded_layers_per_block != expanded_adapter_layers_per_block: + raise ValueError("Incompatible Motion Adapter, got different number of layers per block") + + # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459 + config = dict(unet.config) + config["_class_name"] = cls.__name__ + + down_blocks = [] + for down_blocks_type in config["down_block_types"]: + if "CrossAttn" in down_blocks_type: + down_blocks.append("CrossAttnDownBlockMotion") + else: + down_blocks.append("DownBlockMotion") + config["down_block_types"] = down_blocks + + up_blocks = [] + for down_blocks_type in config["up_block_types"]: + if "CrossAttn" in down_blocks_type: + up_blocks.append("CrossAttnUpBlockMotion") + else: + up_blocks.append("UpBlockMotion") + config["up_block_types"] = up_blocks + + if has_motion_adapter: + config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] + config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] + config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] + config["layers_per_block"] = motion_adapter.config["motion_layers_per_block"] + config["temporal_transformer_layers_per_mid_block"] = motion_adapter.config[ + "motion_transformer_layers_per_mid_block" + ] + config["temporal_transformer_layers_per_block"] = motion_adapter.config[ + "motion_transformer_layers_per_block" + ] + config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] + + # For PIA UNets we need to set the number input channels to 9 + if motion_adapter.config["conv_in_channels"]: + config["in_channels"] = motion_adapter.config["conv_in_channels"] + + # Need this for backwards compatibility with UNet2DConditionModel checkpoints + if not config.get("num_attention_heads"): + config["num_attention_heads"] = config["attention_head_dim"] + + expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) + config = FrozenDict({k: config.get(k) for k in config if k in expected_kwargs or k in optional_kwargs}) + config["_class_name"] = cls.__name__ + model = cls.from_config(config) + + if not load_weights: + return model + + # Logic for loading PIA UNets which allow the first 4 channels to be any UNet2DConditionModel conv_in weight + # while the last 5 channels must be PIA conv_in weights. + if has_motion_adapter and motion_adapter.config["conv_in_channels"]: + model.conv_in = motion_adapter.conv_in + updated_conv_in_weight = torch.cat( + [unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], dim=1 + ) + model.conv_in.load_state_dict({"weight": updated_conv_in_weight, "bias": unet.conv_in.bias}) + else: + model.conv_in.load_state_dict(unet.conv_in.state_dict()) + + model.time_proj.load_state_dict(unet.time_proj.state_dict()) + model.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if any( + isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)) + for proc in unet.attn_processors.values() + ): + attn_procs = {} + for name, processor in unet.attn_processors.items(): + if name.endswith("attn1.processor"): + attn_processor_class = ( + AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + ) + attn_procs[name] = attn_processor_class() + else: + attn_processor_class = ( + IPAdapterAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else IPAdapterAttnProcessor + ) + attn_procs[name] = attn_processor_class( + hidden_size=processor.hidden_size, + cross_attention_dim=processor.cross_attention_dim, + scale=processor.scale, + num_tokens=processor.num_tokens, + ) + for name, processor in model.attn_processors.items(): + if name not in attn_procs: + attn_procs[name] = processor.__class__() + model.set_attn_processor(attn_procs) + model.config.encoder_hid_dim_type = "ip_image_proj" + model.encoder_hid_proj = unet.encoder_hid_proj + + for i, down_block in enumerate(unet.down_blocks): + model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict()) + if hasattr(model.down_blocks[i], "attentions"): + model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict()) + if model.down_blocks[i].downsamplers: + model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict()) + + for i, up_block in enumerate(unet.up_blocks): + model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict()) + if hasattr(model.up_blocks[i], "attentions"): + model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict()) + if model.up_blocks[i].upsamplers: + model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) + + model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) + model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) + + if unet.conv_norm_out is not None: + model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) + if unet.conv_act is not None: + model.conv_act.load_state_dict(unet.conv_act.state_dict()) + model.conv_out.load_state_dict(unet.conv_out.state_dict()) + + if has_motion_adapter: + model.load_motion_modules(motion_adapter) + + # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + + return model + + def freeze_unet2d_params(self) -> None: + """Freeze the weights of just the UNet2DConditionModel, and leave the motion modules + unfrozen for fine tuning. + """ + # Freeze everything + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze Motion Modules + for down_block in self.down_blocks: + motion_modules = down_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + for up_block in self.up_blocks: + motion_modules = up_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + if hasattr(self.mid_block, "motion_modules"): + motion_modules = self.mid_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None: + for i, down_block in enumerate(motion_adapter.down_blocks): + self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict()) + for i, up_block in enumerate(motion_adapter.up_blocks): + self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict()) + + # to support older motion modules that don't have a mid_block + if hasattr(self.mid_block, "motion_modules"): + self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict()) + + def save_motion_modules( + self, + save_directory: str, + is_main_process: bool = True, + safe_serialization: bool = True, + variant: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, + ) -> None: + state_dict = self.state_dict() + + # Extract all motion modules + motion_state_dict = {} + for k, v in state_dict.items(): + if "motion_modules" in k: + motion_state_dict[k] = v + + adapter = MotionAdapter( + block_out_channels=self.config["block_out_channels"], + motion_layers_per_block=self.config["layers_per_block"], + motion_norm_num_groups=self.config["norm_num_groups"], + motion_num_attention_heads=self.config["motion_num_attention_heads"], + motion_max_seq_length=self.config["motion_max_seq_length"], + use_motion_mid_block=self.config["use_motion_mid_block"], + ) + adapter.load_state_dict(motion_state_dict) + adapter.save_pretrained( + save_directory=save_directory, + is_main_process=is_main_process, + safe_serialization=safe_serialization, + variant=variant, + push_to_hub=push_to_hub, + **kwargs, + ) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def disable_forward_chunking(self) -> None: + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self) -> None: + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu + def disable_freeu(self) -> None: + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNetMotionOutput, Tuple[torch.Tensor]]: + r""" + The [`UNetMotionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_motion_model.UNetMotionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unets.unet_motion_model.UNetMotionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_motion_model.UNetMotionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb if aug_emb is None else emb + aug_emb + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds] + encoder_hidden_states = (encoder_hidden_states, image_embeds) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + # To support older versions of motion modules that don't have a mid_block + if hasattr(self.mid_block, "motion_modules"): + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNetMotionOutput(sample=sample) diff --git a/diffusers/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/diffusers/src/diffusers/models/unets/unet_spatio_temporal_condition.py new file mode 100755 index 0000000..9fb975b --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -0,0 +1,490 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import BaseOutput, logging +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetSpatioTemporalConditionOutput(BaseOutput): + """ + The output of [`UNetSpatioTemporalConditionModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor = None + + +class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and + returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + addition_time_embed_dim: (`int`, defaults to 256): + Dimension to to encode the additional time ids. + projection_class_embeddings_input_dim (`int`, defaults to 768): + The dimension of the projection of encoded `added_time_ids`. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], + [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], + [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. + num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): + The number of attention heads. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 8, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + up_block_types: Tuple[str] = ( + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + addition_time_embed_dim: int = 256, + projection_class_embeddings_input_dim: int = 768, + layers_per_block: Union[int, Tuple[int]] = 2, + cross_attention_dim: Union[int, Tuple[int]] = 1024, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20), + num_frames: int = 25, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + padding=1, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=1e-5, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + resnet_act_fn="silu", + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockSpatioTemporal( + block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=1e-5, + resolution_idx=i, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + resnet_act_fn="silu", + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) + self.conv_act = nn.SiLU() + + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=3, + padding=1, + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + added_time_ids: torch.Tensor, + return_dict: bool = True, + ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: + r""" + The [`UNetSpatioTemporalConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. + added_time_ids: (`torch.Tensor`): + The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal + embeddings and added to the time embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead + of a plain tuple. + Returns: + [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is + returned, otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + batch_size, num_frames = sample.shape[:2] + timesteps = timesteps.expand(batch_size) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) + + time_embeds = self.add_time_proj(added_time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + time_embeds = time_embeds.to(emb.dtype) + aug_emb = self.add_embedding(time_embeds) + emb = emb + aug_emb + + # Flatten the batch and frames dimensions + # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] + sample = sample.flatten(0, 1) + # Repeat the embeddings num_video_frames times + # emb: [batch, channels] -> [batch * frames, channels] + emb = emb.repeat_interleave(num_frames, dim=0) + # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] + encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) + + # 2. pre-process + sample = self.conv_in(sample) + + image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + image_only_indicator=image_only_indicator, + ) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + image_only_indicator=image_only_indicator, + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # 7. Reshape back to original shape + sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) + + if not return_dict: + return (sample,) + + return UNetSpatioTemporalConditionOutput(sample=sample) diff --git a/diffusers/src/diffusers/models/unets/unet_stable_cascade.py b/diffusers/src/diffusers/models/unets/unet_stable_cascade.py new file mode 100755 index 0000000..7deea9a --- /dev/null +++ b/diffusers/src/diffusers/models/unets/unet_stable_cascade.py @@ -0,0 +1,611 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import BaseOutput +from ..attention_processor import Attention +from ..modeling_utils import ModelMixin + + +# Copied from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common.WuerstchenLayerNorm with WuerstchenLayerNorm -> SDCascadeLayerNorm +class SDCascadeLayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = super().forward(x) + return x.permute(0, 3, 1, 2) + + +class SDCascadeTimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=[]): + super().__init__() + + self.mapper = nn.Linear(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b + + +class SDCascadeResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + nn.Linear(c * 4, c), + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 +class GlobalResponseNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * stand_div_norm) + self.beta + x + + +class SDCascadeAttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + + self.self_attn = self_attn + self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) + self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + norm_x = self.norm(x) + if self.self_attn: + batch_size, channel, _, _ = x.shape + kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1) + x = x + self.attention(norm_x, encoder_hidden_states=kv) + return x + + +class UpDownBlock2d(nn.Module): + def __init__(self, in_channels, out_channels, mode, enabled=True): + super().__init__() + if mode not in ["up", "down"]: + raise ValueError(f"{mode} not supported") + interpolation = ( + nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True) + if enabled + else nn.Identity() + ) + mapping = nn.Conv2d(in_channels, out_channels, kernel_size=1) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + +@dataclass +class StableCascadeUNetOutput(BaseOutput): + sample: torch.Tensor = None + + +class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + timestep_ratio_embedding_dim: int = 64, + patch_size: int = 1, + conditioning_dim: int = 2048, + block_out_channels: Tuple[int] = (2048, 2048), + num_attention_heads: Tuple[int] = (32, 32), + down_num_layers_per_block: Tuple[int] = (8, 24), + up_num_layers_per_block: Tuple[int] = (24, 8), + down_blocks_repeat_mappers: Optional[Tuple[int]] = ( + 1, + 1, + ), + up_blocks_repeat_mappers: Optional[Tuple[int]] = (1, 1), + block_types_per_layer: Tuple[Tuple[str]] = ( + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ), + clip_text_in_channels: Optional[int] = None, + clip_text_pooled_in_channels=1280, + clip_image_in_channels: Optional[int] = None, + clip_seq=4, + effnet_in_channels: Optional[int] = None, + pixel_mapper_in_channels: Optional[int] = None, + kernel_size=3, + dropout: Union[float, Tuple[float]] = (0.1, 0.1), + self_attn: Union[bool, Tuple[bool]] = True, + timestep_conditioning_type: Tuple[str] = ("sca", "crp"), + switch_level: Optional[Tuple[bool]] = None, + ): + """ + + Parameters: + in_channels (`int`, defaults to 16): + Number of channels in the input sample. + out_channels (`int`, defaults to 16): + Number of channels in the output sample. + timestep_ratio_embedding_dim (`int`, defaults to 64): + Dimension of the projected time embedding. + patch_size (`int`, defaults to 1): + Patch size to use for pixel unshuffling layer + conditioning_dim (`int`, defaults to 2048): + Dimension of the image and text conditional embedding. + block_out_channels (Tuple[int], defaults to (2048, 2048)): + Tuple of output channels for each block. + num_attention_heads (Tuple[int], defaults to (32, 32)): + Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have + attention. + down_num_layers_per_block (Tuple[int], defaults to [8, 24]): + Number of layers in each down block. + up_num_layers_per_block (Tuple[int], defaults to [24, 8]): + Number of layers in each up block. + down_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]): + Number of 1x1 Convolutional layers to repeat in each down block. + up_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]): + Number of 1x1 Convolutional layers to repeat in each up block. + block_types_per_layer (Tuple[Tuple[str]], optional, + defaults to ( + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), ("SDCascadeResBlock", + "SDCascadeTimestepBlock", "SDCascadeAttnBlock") + ): Block types used in each layer of the up/down blocks. + clip_text_in_channels (`int`, *optional*, defaults to `None`): + Number of input channels for CLIP based text conditioning. + clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280): + Number of input channels for pooled CLIP text embeddings. + clip_image_in_channels (`int`, *optional*): + Number of input channels for CLIP based image conditioning. + clip_seq (`int`, *optional*, defaults to 4): + effnet_in_channels (`int`, *optional*, defaults to `None`): + Number of input channels for effnet conditioning. + pixel_mapper_in_channels (`int`, defaults to `None`): + Number of input channels for pixel mapper conditioning. + kernel_size (`int`, *optional*, defaults to 3): + Kernel size to use in the block convolutional layers. + dropout (Tuple[float], *optional*, defaults to (0.1, 0.1)): + Dropout to use per block. + self_attn (Union[bool, Tuple[bool]]): + Tuple of booleans that determine whether to use self attention in a block or not. + timestep_conditioning_type (Tuple[str], defaults to ("sca", "crp")): + Timestep conditioning type. + switch_level (Optional[Tuple[bool]], *optional*, defaults to `None`): + Tuple that indicates whether upsampling or downsampling should be applied in a block + """ + + super().__init__() + + if len(block_out_channels) != len(down_num_layers_per_block): + raise ValueError( + f"Number of elements in `down_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(up_num_layers_per_block): + raise ValueError( + f"Number of elements in `up_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(down_blocks_repeat_mappers): + raise ValueError( + f"Number of elements in `down_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(up_blocks_repeat_mappers): + raise ValueError( + f"Number of elements in `up_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(block_types_per_layer): + raise ValueError( + f"Number of elements in `block_types_per_layer` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + if isinstance(dropout, float): + dropout = (dropout,) * len(block_out_channels) + if isinstance(self_attn, bool): + self_attn = (self_attn,) * len(block_out_channels) + + # CONDITIONING + if effnet_in_channels is not None: + self.effnet_mapper = nn.Sequential( + nn.Conv2d(effnet_in_channels, block_out_channels[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + if pixel_mapper_in_channels is not None: + self.pixels_mapper = nn.Sequential( + nn.Conv2d(pixel_mapper_in_channels, block_out_channels[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + + self.clip_txt_pooled_mapper = nn.Linear(clip_text_pooled_in_channels, conditioning_dim * clip_seq) + if clip_text_in_channels is not None: + self.clip_txt_mapper = nn.Linear(clip_text_in_channels, conditioning_dim) + if clip_image_in_channels is not None: + self.clip_img_mapper = nn.Linear(clip_image_in_channels, conditioning_dim * clip_seq) + self.clip_norm = nn.LayerNorm(conditioning_dim, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(in_channels * (patch_size**2), block_out_channels[0], kernel_size=1), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + + def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == "SDCascadeResBlock": + return SDCascadeResBlock(in_channels, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == "SDCascadeAttnBlock": + return SDCascadeAttnBlock(in_channels, conditioning_dim, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == "SDCascadeTimestepBlock": + return SDCascadeTimestepBlock( + in_channels, timestep_ratio_embedding_dim, conds=timestep_conditioning_type + ) + else: + raise ValueError(f"Block type {block_type} not supported") + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(block_out_channels)): + if i > 0: + self.down_downscalers.append( + nn.Sequential( + SDCascadeLayerNorm(block_out_channels[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d( + block_out_channels[i - 1], block_out_channels[i], mode="down", enabled=switch_level[i - 1] + ) + if switch_level is not None + else nn.Conv2d(block_out_channels[i - 1], block_out_channels[i], kernel_size=2, stride=2), + ) + ) + else: + self.down_downscalers.append(nn.Identity()) + + down_block = nn.ModuleList() + for _ in range(down_num_layers_per_block[i]): + for block_type in block_types_per_layer[i]: + block = get_block( + block_type, + block_out_channels[i], + num_attention_heads[i], + dropout=dropout[i], + self_attn=self_attn[i], + ) + down_block.append(block) + self.down_blocks.append(down_block) + + if down_blocks_repeat_mappers is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(down_blocks_repeat_mappers[i] - 1): + block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(block_out_channels))): + if i > 0: + self.up_upscalers.append( + nn.Sequential( + SDCascadeLayerNorm(block_out_channels[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d( + block_out_channels[i], block_out_channels[i - 1], mode="up", enabled=switch_level[i - 1] + ) + if switch_level is not None + else nn.ConvTranspose2d( + block_out_channels[i], block_out_channels[i - 1], kernel_size=2, stride=2 + ), + ) + ) + else: + self.up_upscalers.append(nn.Identity()) + + up_block = nn.ModuleList() + for j in range(up_num_layers_per_block[::-1][i]): + for k, block_type in enumerate(block_types_per_layer[i]): + c_skip = block_out_channels[i] if i < len(block_out_channels) - 1 and j == k == 0 else 0 + block = get_block( + block_type, + block_out_channels[i], + num_attention_heads[i], + c_skip=c_skip, + dropout=dropout[i], + self_attn=self_attn[i], + ) + up_block.append(block) + self.up_blocks.append(up_block) + + if up_blocks_repeat_mappers is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(up_blocks_repeat_mappers[::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(block_out_channels[0], out_channels * (patch_size**2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, value=False): + self.gradient_checkpointing = value + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) + nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) if hasattr(self, "clip_txt_mapper") else None + nn.init.normal_(self.clip_img_mapper.weight, std=0.02) if hasattr(self, "clip_img_mapper") else None + + if hasattr(self, "effnet_mapper"): + nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + + if hasattr(self, "pixels_mapper"): + nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, SDCascadeResBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks[0])) + elif isinstance(block, SDCascadeTimestepBlock): + nn.init.constant_(block.mapper.weight, 0) + + def get_timestep_ratio_embedding(self, timestep_ratio, max_positions=10000): + r = timestep_ratio * max_positions + half_dim = self.config.timestep_ratio_embedding_dim // 2 + + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + + if self.config.timestep_ratio_embedding_dim % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + + return emb.to(dtype=r.dtype) + + def get_clip_embeddings(self, clip_txt_pooled, clip_txt=None, clip_img=None): + if len(clip_txt_pooled.shape) == 2: + clip_txt_pool = clip_txt_pooled.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view( + clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.config.clip_seq, -1 + ) + if clip_txt is not None and clip_img is not None: + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_img = self.clip_img_mapper(clip_img).view( + clip_img.size(0), clip_img.size(1) * self.config.clip_seq, -1 + ) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + else: + clip = clip_txt_pool + return self.clip_norm(clip) + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, SDCascadeResBlock): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) + elif isinstance(block, SDCascadeAttnBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, clip, use_reentrant=False + ) + elif isinstance(block, SDCascadeTimestepBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, r_embed, use_reentrant=False + ) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + else: + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, SDCascadeResBlock): + x = block(x) + elif isinstance(block, SDCascadeAttnBlock): + x = block(x, clip) + elif isinstance(block, SDCascadeTimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, SDCascadeResBlock): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + orig_type = x.dtype + x = torch.nn.functional.interpolate( + x.float(), skip.shape[-2:], mode="bilinear", align_corners=True + ) + x = x.to(orig_type) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, skip, use_reentrant=False + ) + elif isinstance(block, SDCascadeAttnBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, clip, use_reentrant=False + ) + elif isinstance(block, SDCascadeTimestepBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, r_embed, use_reentrant=False + ) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + else: + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, SDCascadeResBlock): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + orig_type = x.dtype + x = torch.nn.functional.interpolate( + x.float(), skip.shape[-2:], mode="bilinear", align_corners=True + ) + x = x.to(orig_type) + x = block(x, skip) + elif isinstance(block, SDCascadeAttnBlock): + x = block(x, clip) + elif isinstance(block, SDCascadeTimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward( + self, + sample, + timestep_ratio, + clip_text_pooled, + clip_text=None, + clip_img=None, + effnet=None, + pixels=None, + sca=None, + crp=None, + return_dict=True, + ): + if pixels is None: + pixels = sample.new_zeros(sample.size(0), 3, 8, 8) + + # Process the conditioning embeddings + timestep_ratio_embed = self.get_timestep_ratio_embedding(timestep_ratio) + for c in self.config.timestep_conditioning_type: + if c == "sca": + cond = sca + elif c == "crp": + cond = crp + else: + cond = None + t_cond = cond or torch.zeros_like(timestep_ratio) + timestep_ratio_embed = torch.cat([timestep_ratio_embed, self.get_timestep_ratio_embedding(t_cond)], dim=1) + clip = self.get_clip_embeddings(clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img) + + # Model Blocks + x = self.embedding(sample) + if hasattr(self, "effnet_mapper") and effnet is not None: + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True) + ) + if hasattr(self, "pixels_mapper"): + x = x + nn.functional.interpolate( + self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True + ) + level_outputs = self._down_encode(x, timestep_ratio_embed, clip) + x = self._up_decode(level_outputs, timestep_ratio_embed, clip) + sample = self.clf(x) + + if not return_dict: + return (sample,) + return StableCascadeUNetOutput(sample=sample) diff --git a/diffusers/src/diffusers/models/unets/uvit_2d.py b/diffusers/src/diffusers/models/unets/uvit_2d.py new file mode 100755 index 0000000..8a379bf --- /dev/null +++ b/diffusers/src/diffusers/models/unets/uvit_2d.py @@ -0,0 +1,470 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ..attention import BasicTransformerBlock, SkipFFTransformerBlock +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..embeddings import TimestepEmbedding, get_timestep_embedding +from ..modeling_utils import ModelMixin +from ..normalization import GlobalResponseNorm, RMSNorm +from ..resnet import Downsample2D, Upsample2D + + +class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + # global config + hidden_size: int = 1024, + use_bias: bool = False, + hidden_dropout: float = 0.0, + # conditioning dimensions + cond_embed_dim: int = 768, + micro_cond_encode_dim: int = 256, + micro_cond_embed_dim: int = 1280, + encoder_hidden_size: int = 768, + # num tokens + vocab_size: int = 8256, # codebook_size + 1 (for the mask token) rounded + codebook_size: int = 8192, + # `UVit2DConvEmbed` + in_channels: int = 768, + block_out_channels: int = 768, + num_res_blocks: int = 3, + downsample: bool = False, + upsample: bool = False, + block_num_heads: int = 12, + # `TransformerLayer` + num_hidden_layers: int = 22, + num_attention_heads: int = 16, + # `Attention` + attention_dropout: float = 0.0, + # `FeedForward` + intermediate_size: int = 2816, + # `Norm` + layer_norm_eps: float = 1e-6, + ln_elementwise_affine: bool = True, + sample_size: int = 64, + ): + super().__init__() + + self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias) + self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) + + self.embed = UVit2DConvEmbed( + in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias + ) + + self.cond_embed = TimestepEmbedding( + micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias + ) + + self.down_block = UVitBlock( + block_out_channels, + num_res_blocks, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample, + False, + ) + + self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine) + self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias) + + self.transformer_layers = nn.ModuleList( + [ + BasicTransformerBlock( + dim=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_dim=hidden_size // num_attention_heads, + dropout=hidden_dropout, + cross_attention_dim=hidden_size, + attention_bias=use_bias, + norm_type="ada_norm_continuous", + ada_norm_continous_conditioning_embedding_dim=hidden_size, + norm_elementwise_affine=ln_elementwise_affine, + norm_eps=layer_norm_eps, + ada_norm_bias=use_bias, + ff_inner_dim=intermediate_size, + ff_bias=use_bias, + attention_out_bias=use_bias, + ) + for _ in range(num_hidden_layers) + ] + ) + + self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) + self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias) + + self.up_block = UVitBlock( + block_out_channels, + num_res_blocks, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample=False, + upsample=upsample, + ) + + self.mlm_layer = ConvMlmLayer( + block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size + ) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + pass + + def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None): + encoder_hidden_states = self.encoder_proj(encoder_hidden_states) + encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states) + + micro_cond_embeds = get_timestep_embedding( + micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1)) + + pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1) + pooled_text_emb = pooled_text_emb.to(dtype=self.dtype) + pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype) + + hidden_states = self.embed(input_ids) + + hidden_states = self.down_block( + hidden_states, + pooled_text_emb=pooled_text_emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels) + + hidden_states = self.project_to_hidden_norm(hidden_states) + hidden_states = self.project_to_hidden(hidden_states) + + for layer in self.transformer_layers: + if self.training and self.gradient_checkpointing: + + def layer_(*args): + return checkpoint(layer, *args) + + else: + layer_ = layer + + hidden_states = layer_( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs={"pooled_text_emb": pooled_text_emb}, + ) + + hidden_states = self.project_from_hidden_norm(hidden_states) + hidden_states = self.project_from_hidden(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + + hidden_states = self.up_block( + hidden_states, + pooled_text_emb=pooled_text_emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + logits = self.mlm_layer(hidden_states) + + return logits + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + +class UVit2DConvEmbed(nn.Module): + def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias): + super().__init__() + self.embeddings = nn.Embedding(vocab_size, in_channels) + self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine) + self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias) + + def forward(self, input_ids): + embeddings = self.embeddings(input_ids) + embeddings = self.layer_norm(embeddings) + embeddings = embeddings.permute(0, 3, 1, 2) + embeddings = self.conv(embeddings) + return embeddings + + +class UVitBlock(nn.Module): + def __init__( + self, + channels, + num_res_blocks: int, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample: bool, + upsample: bool, + ): + super().__init__() + + if downsample: + self.downsample = Downsample2D( + channels, + use_conv=True, + padding=0, + name="Conv2d_0", + kernel_size=2, + norm_type="rms_norm", + eps=layer_norm_eps, + elementwise_affine=ln_elementwise_affine, + bias=use_bias, + ) + else: + self.downsample = None + + self.res_blocks = nn.ModuleList( + [ + ConvNextBlock( + channels, + layer_norm_eps, + ln_elementwise_affine, + use_bias, + hidden_dropout, + hidden_size, + ) + for i in range(num_res_blocks) + ] + ) + + self.attention_blocks = nn.ModuleList( + [ + SkipFFTransformerBlock( + channels, + block_num_heads, + channels // block_num_heads, + hidden_size, + use_bias, + attention_dropout, + channels, + attention_bias=use_bias, + attention_out_bias=use_bias, + ) + for _ in range(num_res_blocks) + ] + ) + + if upsample: + self.upsample = Upsample2D( + channels, + use_conv_transpose=True, + kernel_size=2, + padding=0, + name="conv", + norm_type="rms_norm", + eps=layer_norm_eps, + elementwise_affine=ln_elementwise_affine, + bias=use_bias, + interpolate=False, + ) + else: + self.upsample = None + + def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs): + if self.downsample is not None: + x = self.downsample(x) + + for res_block, attention_block in zip(self.res_blocks, self.attention_blocks): + x = res_block(x, pooled_text_emb) + + batch_size, channels, height, width = x.shape + x = x.view(batch_size, channels, height * width).permute(0, 2, 1) + x = attention_block( + x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs + ) + x = x.permute(0, 2, 1).view(batch_size, channels, height, width) + + if self.upsample is not None: + x = self.upsample(x) + + return x + + +class ConvNextBlock(nn.Module): + def __init__( + self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4 + ): + super().__init__() + self.depthwise = nn.Conv2d( + channels, + channels, + kernel_size=3, + padding=1, + groups=channels, + bias=use_bias, + ) + self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine) + self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias) + self.channelwise_act = nn.GELU() + self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor)) + self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias) + self.channelwise_dropout = nn.Dropout(hidden_dropout) + self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias) + + def forward(self, x, cond_embeds): + x_res = x + + x = self.depthwise(x) + + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + + x = self.channelwise_linear_1(x) + x = self.channelwise_act(x) + x = self.channelwise_norm(x) + x = self.channelwise_linear_2(x) + x = self.channelwise_dropout(x) + + x = x.permute(0, 3, 1, 2) + + x = x + x_res + + scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1) + x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None] + + return x + + +class ConvMlmLayer(nn.Module): + def __init__( + self, + block_out_channels: int, + in_channels: int, + use_bias: bool, + ln_elementwise_affine: bool, + layer_norm_eps: float, + codebook_size: int, + ): + super().__init__() + self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias) + self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine) + self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias) + + def forward(self, hidden_states): + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + logits = self.conv2(hidden_states) + return logits diff --git a/diffusers/src/diffusers/models/upsampling.py b/diffusers/src/diffusers/models/upsampling.py new file mode 100755 index 0000000..fd5ed28 --- /dev/null +++ b/diffusers/src/diffusers/models/upsampling.py @@ -0,0 +1,509 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import deprecate +from .normalization import RMSNorm + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + name (`str`, default `conv`): + name of the upsampling 1D layer. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + + +class Upsample2D(nn.Module): + """A 2D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + name (`str`, default `conv`): + name of the upsampling 2D layer. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + interpolate=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.interpolate = interpolate + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + conv = None + if use_conv_transpose: + if kernel_size is None: + kernel_size = 4 + conv = nn.ConvTranspose2d( + channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias + ) + elif use_conv: + if kernel_size is None: + kernel_size = 3 + conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + if self.use_conv_transpose: + return self.conv(hidden_states) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if self.interpolate: + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class FirUpsample2D(nn.Module): + """A 2D FIR upsampling layer with an optional convolution. + + Parameters: + channels (`int`, optional): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + + def __init__( + self, + channels: Optional[int] = None, + out_channels: Optional[int] = None, + use_conv: bool = False, + fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), + ): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d( + self, + hidden_states: torch.Tensor, + weight: Optional[torch.Tensor] = None, + kernel: Optional[torch.Tensor] = None, + factor: int = 2, + gain: float = 1, + ) -> torch.Tensor: + """Fused `upsample_2d()` followed by `Conv2d()`. + + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states (`torch.Tensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight (`torch.Tensor`, *optional*): + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel (`torch.Tensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to nearest-neighbor upsampling. + factor (`int`, *optional*): Integer upsampling factor (default: 2). + gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0). + + Returns: + output (`torch.Tensor`): + Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same + datatype as `hidden_states`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Setup filter kernel. + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + + if self.use_conv: + convH = weight.shape[2] + convW = weight.shape[3] + inC = weight.shape[1] + + pad_value = (kernel.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + # Determine data dimensions. + output_shape = ( + (hidden_states.shape[2] - 1) * factor + convH, + (hidden_states.shape[3] - 1) * factor + convW, + ) + output_padding = ( + output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = hidden_states.shape[1] // inC + + # Transpose weights. + weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) + weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) + weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) + + inverse_conv = F.conv_transpose2d( + hidden_states, + weight, + stride=stride, + output_padding=output_padding, + padding=0, + ) + + output = upfirdn2d_native( + inverse_conv, + torch.tensor(kernel, device=inverse_conv.device), + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), + ) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + + return output + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.use_conv: + height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) + height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) + + return height + + +class KUpsample2D(nn.Module): + r"""A 2D K-upsampling layer. + + Parameters: + pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. + """ + + def __init__(self, pad_mode: str = "reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode) + weight = inputs.new_zeros( + [ + inputs.shape[1], + inputs.shape[1], + self.kernel.shape[0], + self.kernel.shape[1], + ] + ) + indices = torch.arange(inputs.shape[1], device=inputs.device) + kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) + weight[indices, indices] = kernel + return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) + + +class CogVideoXUpsample3D(nn.Module): + r""" + A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase. + + Args: + in_channels (`int`): + Number of channels in the input image. + out_channels (`int`): + Number of channels produced by the convolution. + kernel_size (`int`, defaults to `3`): + Size of the convolving kernel. + stride (`int`, defaults to `1`): + Stride of the convolution. + padding (`int`, defaults to `1`): + Padding added to all four sides of the input. + compress_time (`bool`, defaults to `False`): + Whether or not to compress the time dimension. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + compress_time: bool = False, + ) -> None: + super().__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if self.compress_time: + if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] + + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x_first = x_first[:, :, None, :, :] + inputs = torch.cat([x_first, x_rest], dim=2) + elif inputs.shape[2] > 1: + inputs = F.interpolate(inputs, scale_factor=2.0) + else: + inputs = inputs.squeeze(2) + inputs = F.interpolate(inputs, scale_factor=2.0) + inputs = inputs[:, :, None, :, :] + else: + # only interpolate 2D + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = F.interpolate(inputs, scale_factor=2.0) + inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = self.conv(inputs) + inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4) + + return inputs + + +def upfirdn2d_native( + tensor: torch.Tensor, + kernel: torch.Tensor, + up: int = 1, + down: int = 1, + pad: Tuple[int, int] = (0, 0), +) -> torch.Tensor: + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + + _, channel, in_h, in_w = tensor.shape + tensor = tensor.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = tensor.shape + kernel_h, kernel_w = kernel.shape + + out = tensor.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out.to(tensor.device) # Move back to mps if necessary + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) + + +def upsample_2d( + hidden_states: torch.Tensor, + kernel: Optional[torch.Tensor] = None, + factor: int = 2, + gain: float = 1, +) -> torch.Tensor: + r"""Upsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is + a: multiple of the upsampling factor. + + Args: + hidden_states (`torch.Tensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel (`torch.Tensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to nearest-neighbor upsampling. + factor (`int`, *optional*, default to `2`): + Integer upsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude (default: 1.0). + + Returns: + output (`torch.Tensor`): + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel.to(device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + return output diff --git a/diffusers/src/diffusers/models/vae_flax.py b/diffusers/src/diffusers/models/vae_flax.py new file mode 100755 index 0000000..5027f42 --- /dev/null +++ b/diffusers/src/diffusers/models/vae_flax.py @@ -0,0 +1,876 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers + +import math +from functools import partial +from typing import Tuple + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ..configuration_utils import ConfigMixin, flax_register_to_config +from ..utils import BaseOutput +from .modeling_flax_utils import FlaxModelMixin + + +@flax.struct.dataclass +class FlaxDecoderOutput(BaseOutput): + """ + Output of decoding method. + + Args: + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + The `dtype` of the parameters. + """ + + sample: jnp.ndarray + + +@flax.struct.dataclass +class FlaxAutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. + + Args: + latent_dist (`FlaxDiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`. + `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "FlaxDiagonalGaussianDistribution" + + +class FlaxUpsample2D(nn.Module): + """ + Flax implementation of 2D Upsample layer + + Args: + in_channels (`int`): + Input channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.in_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + batch, height, width, channels = hidden_states.shape + hidden_states = jax.image.resize( + hidden_states, + shape=(batch, height * 2, width * 2, channels), + method="nearest", + ) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxDownsample2D(nn.Module): + """ + Flax implementation of 2D Downsample layer + + Args: + in_channels (`int`): + Input channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.in_channels, + kernel_size=(3, 3), + strides=(2, 2), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim + hidden_states = jnp.pad(hidden_states, pad_width=pad) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxResnetBlock2D(nn.Module): + """ + Flax implementation of 2D Resnet Block. + + Args: + in_channels (`int`): + Input channels + out_channels (`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + groups (:obj:`int`, *optional*, defaults to `32`): + The number of groups to use for group norm. + use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): + Whether to use `nin_shortcut`. This activates a new layer inside ResNet block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int = None + dropout: float = 0.0 + groups: int = 32 + use_nin_shortcut: bool = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + out_channels = self.in_channels if self.out_channels is None else self.out_channels + + self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) + self.conv1 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) + self.dropout_layer = nn.Dropout(self.dropout) + self.conv2 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut + + self.conv_shortcut = None + if use_nin_shortcut: + self.conv_shortcut = nn.Conv( + out_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, deterministic=True): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual + + +class FlaxAttentionBlock(nn.Module): + r""" + Flax Convolutional based multi-head attention block for diffusion-based VAE. + + Parameters: + channels (:obj:`int`): + Input channels + num_head_channels (:obj:`int`, *optional*, defaults to `None`): + Number of attention heads + num_groups (:obj:`int`, *optional*, defaults to `32`): + The number of groups to use for group norm + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + + """ + + channels: int + num_head_channels: int = None + num_groups: int = 32 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1 + + dense = partial(nn.Dense, self.channels, dtype=self.dtype) + + self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6) + self.query, self.key, self.value = dense(), dense(), dense() + self.proj_attn = dense() + + def transpose_for_scores(self, projection): + new_projection_shape = projection.shape[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) + new_projection = projection.reshape(new_projection_shape) + # (B, T, H, D) -> (B, H, T, D) + new_projection = jnp.transpose(new_projection, (0, 2, 1, 3)) + return new_projection + + def __call__(self, hidden_states): + residual = hidden_states + batch, height, width, channels = hidden_states.shape + + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.reshape((batch, height * width, channels)) + + query = self.query(hidden_states) + key = self.key(hidden_states) + value = self.value(hidden_states) + + # transpose + query = self.transpose_for_scores(query) + key = self.transpose_for_scores(key) + value = self.transpose_for_scores(value) + + # compute attentions + scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) + attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale) + attn_weights = nn.softmax(attn_weights, axis=-1) + + # attend to values + hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights) + + hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3)) + new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,) + hidden_states = hidden_states.reshape(new_hidden_states_shape) + + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.reshape((batch, height, width, channels)) + hidden_states = hidden_states + residual + return hidden_states + + +class FlaxDownEncoderBlock2D(nn.Module): + r""" + Flax Resnet blocks-based Encoder block for diffusion-based VAE. + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + resnet_groups (:obj:`int`, *optional*, defaults to `32`): + The number of groups to use for the Resnet block group norm + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsample layer + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + resnet_groups: int = 32 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout=self.dropout, + groups=self.resnet_groups, + dtype=self.dtype, + ) + resnets.append(res_block) + self.resnets = resnets + + if self.add_downsample: + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, deterministic=deterministic) + + if self.add_downsample: + hidden_states = self.downsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUpDecoderBlock2D(nn.Module): + r""" + Flax Resnet blocks-based Decoder block for diffusion-based VAE. + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + resnet_groups (:obj:`int`, *optional*, defaults to `32`): + The number of groups to use for the Resnet block group norm + add_upsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add upsample layer + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + resnet_groups: int = 32 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout=self.dropout, + groups=self.resnet_groups, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + + if self.add_upsample: + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUNetMidBlock2D(nn.Module): + r""" + Flax Unet Mid-Block module. + + Parameters: + in_channels (:obj:`int`): + Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + resnet_groups (:obj:`int`, *optional*, defaults to `32`): + The number of groups to use for the Resnet and Attention block group norm + num_attention_heads (:obj:`int`, *optional*, defaults to `1`): + Number of attention heads for each attention block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + dropout: float = 0.0 + num_layers: int = 1 + resnet_groups: int = 32 + num_attention_heads: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout=self.dropout, + groups=resnet_groups, + dtype=self.dtype, + ) + ] + + attentions = [] + + for _ in range(self.num_layers): + attn_block = FlaxAttentionBlock( + channels=self.in_channels, + num_head_channels=self.num_attention_heads, + num_groups=resnet_groups, + dtype=self.dtype, + ) + attentions.append(attn_block) + + res_block = FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout=self.dropout, + groups=resnet_groups, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + self.attentions = attentions + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.resnets[0](hidden_states, deterministic=deterministic) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, deterministic=deterministic) + + return hidden_states + + +class FlaxEncoder(nn.Module): + r""" + Flax Implementation of VAE Encoder. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): + DownEncoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + norm_num_groups (:obj:`int`, *optional*, defaults to `32`): + norm num group + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + double_z (:obj:`bool`, *optional*, defaults to `False`): + Whether to double the last output channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int = 3 + out_channels: int = 3 + down_block_types: Tuple[str] = ("DownEncoderBlock2D",) + block_out_channels: Tuple[int] = (64,) + layers_per_block: int = 2 + norm_num_groups: int = 32 + act_fn: str = "silu" + double_z: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + block_out_channels = self.block_out_channels + # in + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # downsampling + down_blocks = [] + output_channel = block_out_channels[0] + for i, _ in enumerate(self.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = FlaxDownEncoderBlock2D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=self.layers_per_block, + resnet_groups=self.norm_num_groups, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + down_blocks.append(down_block) + self.down_blocks = down_blocks + + # middle + self.mid_block = FlaxUNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_groups=self.norm_num_groups, + num_attention_heads=None, + dtype=self.dtype, + ) + + # end + conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels + self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) + self.conv_out = nn.Conv( + conv_out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, sample, deterministic: bool = True): + # in + sample = self.conv_in(sample) + + # downsampling + for block in self.down_blocks: + sample = block(sample, deterministic=deterministic) + + # middle + sample = self.mid_block(sample, deterministic=deterministic) + + # end + sample = self.conv_norm_out(sample) + sample = nn.swish(sample) + sample = self.conv_out(sample) + + return sample + + +class FlaxDecoder(nn.Module): + r""" + Flax Implementation of VAE Decoder. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): + UpDecoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + norm_num_groups (:obj:`int`, *optional*, defaults to `32`): + norm num group + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + double_z (:obj:`bool`, *optional*, defaults to `False`): + Whether to double the last output channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ + + in_channels: int = 3 + out_channels: int = 3 + up_block_types: Tuple[str] = ("UpDecoderBlock2D",) + block_out_channels: int = (64,) + layers_per_block: int = 2 + norm_num_groups: int = 32 + act_fn: str = "silu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + block_out_channels = self.block_out_channels + + # z to block_in + self.conv_in = nn.Conv( + block_out_channels[-1], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # middle + self.mid_block = FlaxUNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_groups=self.norm_num_groups, + num_attention_heads=None, + dtype=self.dtype, + ) + + # upsampling + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + up_blocks = [] + for i, _ in enumerate(self.up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = FlaxUpDecoderBlock2D( + in_channels=prev_output_channel, + out_channels=output_channel, + num_layers=self.layers_per_block + 1, + resnet_groups=self.norm_num_groups, + add_upsample=not is_final_block, + dtype=self.dtype, + ) + up_blocks.append(up_block) + prev_output_channel = output_channel + + self.up_blocks = up_blocks + + # end + self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) + self.conv_out = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, sample, deterministic: bool = True): + # z to block_in + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample, deterministic=deterministic) + + # upsampling + for block in self.up_blocks: + sample = block(sample, deterministic=deterministic) + + sample = self.conv_norm_out(sample) + sample = nn.swish(sample) + sample = self.conv_out(sample) + + return sample + + +class FlaxDiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + # Last axis to account for channels-last + self.mean, self.logvar = jnp.split(parameters, 2, axis=-1) + self.logvar = jnp.clip(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = jnp.exp(0.5 * self.logvar) + self.var = jnp.exp(self.logvar) + if self.deterministic: + self.var = self.std = jnp.zeros_like(self.mean) + + def sample(self, key): + return self.mean + self.std * jax.random.normal(key, self.mean.shape) + + def kl(self, other=None): + if self.deterministic: + return jnp.array([0.0]) + + if other is None: + return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3]) + + return 0.5 * jnp.sum( + jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, + axis=[1, 2, 3], + ) + + def nll(self, sample, axis=[1, 2, 3]): + if self.deterministic: + return jnp.array([0.0]) + + logtwopi = jnp.log(2.0 * jnp.pi) + return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis) + + def mode(self): + return self.mean + + +@flax_register_to_config +class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + Flax implementation of a VAE model with KL loss for decoding latent representations. + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods + implemented for all models (such as downloading or saving). + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matter related to its + general usage and behavior. + + Inherent JAX features such as the following are supported: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): + Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): + Tuple of upsample block types. + block_out_channels (`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): + Number of ResNet layer for each block. + act_fn (`str`, *optional*, defaults to `silu`): + The activation function to use. + latent_channels (`int`, *optional*, defaults to `4`): + Number of channels in the latent space. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups for normalization. + sample_size (`int`, *optional*, defaults to 32): + Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + The `dtype` of the parameters. + """ + + in_channels: int = 3 + out_channels: int = 3 + down_block_types: Tuple[str] = ("DownEncoderBlock2D",) + up_block_types: Tuple[str] = ("UpDecoderBlock2D",) + block_out_channels: Tuple[int] = (64,) + layers_per_block: int = 1 + act_fn: str = "silu" + latent_channels: int = 4 + norm_num_groups: int = 32 + sample_size: int = 32 + scaling_factor: float = 0.18215 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.encoder = FlaxEncoder( + in_channels=self.config.in_channels, + out_channels=self.config.latent_channels, + down_block_types=self.config.down_block_types, + block_out_channels=self.config.block_out_channels, + layers_per_block=self.config.layers_per_block, + act_fn=self.config.act_fn, + norm_num_groups=self.config.norm_num_groups, + double_z=True, + dtype=self.dtype, + ) + self.decoder = FlaxDecoder( + in_channels=self.config.latent_channels, + out_channels=self.config.out_channels, + up_block_types=self.config.up_block_types, + block_out_channels=self.config.block_out_channels, + layers_per_block=self.config.layers_per_block, + norm_num_groups=self.config.norm_num_groups, + act_fn=self.config.act_fn, + dtype=self.dtype, + ) + self.quant_conv = nn.Conv( + 2 * self.config.latent_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + self.post_quant_conv = nn.Conv( + self.config.latent_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def init_weights(self, rng: jax.Array) -> FrozenDict: + # init input tensors + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + + params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3) + rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng} + + return self.init(rngs, sample)["params"] + + def encode(self, sample, deterministic: bool = True, return_dict: bool = True): + sample = jnp.transpose(sample, (0, 2, 3, 1)) + + hidden_states = self.encoder(sample, deterministic=deterministic) + moments = self.quant_conv(hidden_states) + posterior = FlaxDiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return FlaxAutoencoderKLOutput(latent_dist=posterior) + + def decode(self, latents, deterministic: bool = True, return_dict: bool = True): + if latents.shape[-1] != self.config.latent_channels: + latents = jnp.transpose(latents, (0, 2, 3, 1)) + + hidden_states = self.post_quant_conv(latents) + hidden_states = self.decoder(hidden_states, deterministic=deterministic) + + hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) + + if not return_dict: + return (hidden_states,) + + return FlaxDecoderOutput(sample=hidden_states) + + def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True): + posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict) + if sample_posterior: + rng = self.make_rng("gaussian") + hidden_states = posterior.latent_dist.sample(rng) + else: + hidden_states = posterior.latent_dist.mode() + + sample = self.decode(hidden_states, return_dict=return_dict).sample + + if not return_dict: + return (sample,) + + return FlaxDecoderOutput(sample=sample) diff --git a/diffusers/src/diffusers/models/vq_model.py b/diffusers/src/diffusers/models/vq_model.py new file mode 100755 index 0000000..f946e46 --- /dev/null +++ b/diffusers/src/diffusers/models/vq_model.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ..utils import deprecate +from .autoencoders.vq_model import VQEncoderOutput, VQModel + + +class VQEncoderOutput(VQEncoderOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `VQEncoderOutput` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQEncoderOutput`, instead." + deprecate("VQEncoderOutput", "0.31", deprecation_message) + super().__init__(*args, **kwargs) + + +class VQModel(VQModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `VQModel` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQModel`, instead." + deprecate("VQModel", "0.31", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/diffusers/src/diffusers/optimization.py b/diffusers/src/diffusers/optimization.py new file mode 100755 index 0000000..f20bd94 --- /dev/null +++ b/diffusers/src/diffusers/optimization.py @@ -0,0 +1,361 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch optimization for diffusion models.""" + +import math +from enum import Enum +from typing import Optional, Union + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from .utils import logging + + +logger = logging.get_logger(__name__) + + +class SchedulerType(Enum): + LINEAR = "linear" + COSINE = "cosine" + COSINE_WITH_RESTARTS = "cosine_with_restarts" + POLYNOMIAL = "polynomial" + CONSTANT = "constant" + CONSTANT_WITH_WARMUP = "constant_with_warmup" + PIECEWISE_CONSTANT = "piecewise_constant" + + +def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaLR: + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) + + +def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1) -> LambdaLR: + """ + Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate + increases linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1) -> LambdaLR: + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + step_rules (`string`): + The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate + if multiple 1 for the first 10 steps, multiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 + steps and multiple 0.005 for the other steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + rules_dict = {} + rule_list = step_rules.split(",") + for rule_str in rule_list[:-1]: + value_str, steps_str = rule_str.split(":") + steps = int(steps_str) + value = float(value_str) + rules_dict[steps] = value + last_lr_multiple = float(rule_list[-1]) + + def create_rules_function(rules_dict, last_lr_multiple): + def rule_func(steps: int) -> float: + sorted_steps = sorted(rules_dict.keys()) + for i, sorted_step in enumerate(sorted_steps): + if steps < sorted_step: + return rules_dict[sorted_steps[i]] + return last_lr_multiple + + return rule_func + + rules_func = create_rules_function(rules_dict, last_lr_multiple) + + return LambdaLR(optimizer, rules_func, last_epoch=last_epoch) + + +def get_linear_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1 +) -> LambdaLR: + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 +) -> LambdaLR: + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_periods (`float`, *optional*, defaults to 0.5): + The number of periods of the cosine function in a schedule (the default is to just decrease from the max + value to 0 following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 +) -> LambdaLR: + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.0: + return 0.0 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_polynomial_decay_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + lr_end: float = 1e-7, + power: float = 1.0, + last_epoch: int = -1, +) -> LambdaLR: + """ + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT + implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + + """ + + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +TYPE_TO_SCHEDULER_FUNCTION = { + SchedulerType.LINEAR: get_linear_schedule_with_warmup, + SchedulerType.COSINE: get_cosine_schedule_with_warmup, + SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, + SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, + SchedulerType.CONSTANT: get_constant_schedule, + SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, + SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule, +} + + +def get_scheduler( + name: Union[str, SchedulerType], + optimizer: Optimizer, + step_rules: Optional[str] = None, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, + num_cycles: int = 1, + power: float = 1.0, + last_epoch: int = -1, +) -> LambdaLR: + """ + Unified API to get any scheduler from its name. + + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + step_rules (`str`, *optional*): + A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_cycles (`int`, *optional*): + The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. + power (`float`, *optional*, defaults to 1.0): + Power factor. See `POLYNOMIAL` scheduler + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer, last_epoch=last_epoch) + + if name == SchedulerType.PIECEWISE_CONSTANT: + return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + last_epoch=last_epoch, + ) + + if name == SchedulerType.POLYNOMIAL: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + power=power, + last_epoch=last_epoch, + ) + + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch + ) diff --git a/diffusers/src/diffusers/pipelines/README.md b/diffusers/src/diffusers/pipelines/README.md new file mode 100755 index 0000000..b2954c0 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/README.md @@ -0,0 +1,171 @@ +# 🧨 Diffusers Pipelines + +Pipelines provide a simple way to run state-of-the-art diffusion models in inference. +Most diffusion systems consist of multiple independently-trained models and highly adaptable scheduler +components - all of which are needed to have a functioning end-to-end diffusion system. + +As an example, [Stable Diffusion](https://huggingface.co/blog/stable_diffusion) has three independently trained models: +- [Autoencoder](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/vae.py#L392) +- [Conditional Unet](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/unet_2d_condition.py#L12) +- [CLIP text encoder](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTextModel) +- a scheduler component, [scheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py), +- a [CLIPImageProcessor](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPImageProcessor), +- as well as a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py). +All of these components are necessary to run stable diffusion in inference even though they were trained +or created independently from each other. + +To that end, we strive to offer all open-sourced, state-of-the-art diffusion system under a unified API. +More specifically, we strive to provide pipelines that +- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)), +- 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section), +- 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)), +- 4. can easily be contributed by the community (see the [Contribution](#contribution) section). + +**Note** that pipelines do not (and should not) offer any training functionality. +If you are looking for *official* training examples, please have a look at [examples](https://github.com/huggingface/diffusers/tree/main/examples). + + +## Pipelines Summary + +The following table summarizes all officially supported pipelines, their corresponding paper, and if +available a colab notebook to directly try them out. + +| Pipeline | Source | Tasks | Colab +|-------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|:---:|:---:| +| [dance diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/Harmonai-org/sample-generator) | *Unconditional Audio Generation* | +| [ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | *Unconditional Image Generation* | +| [ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | *Unconditional Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Text-to-Image Generation* | +| [latent_diffusion_uncond](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Unconditional Image Generation* | +| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* | +| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | +| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* | + +**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. +However, most of them can be adapted to use different scheduler components or even different model components. Some pipeline examples are shown in the [Examples](#examples) below. + +## Pipelines API + +Diffusion models often consist of multiple independently-trained models or other previously existing components. + + +Each model has been trained independently on a different task and the scheduler can easily be swapped out and replaced with a different one. +During inference, we however want to be able to easily load all components and use them in inference - even if one component, *e.g.* CLIP's text encoder, originates from a different library, such as [Transformers](https://github.com/huggingface/transformers). To that end, all pipelines provide the following functionality: + +- [`from_pretrained` method](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L139) that accepts a Hugging Face Hub repository id, *e.g.* [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) or a path to a local directory, *e.g.* +"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [stable-diffusion-v1-5/stable-diffusion-v1-5/model_index.json](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/model_index.json), which defines all components that should be +loaded into the pipelines. More specifically, for each model/component one needs to define the format `: ["", ""]`. `` is the attribute name given to the loaded instance of `` which can be found in the library or pipeline folder called `""`. +- [`save_pretrained`](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L90) that accepts a local path, *e.g.* `./stable-diffusion` under which all models/components of the pipeline will be saved. For each component/model a folder is created inside the local path that is named after the given attribute name, *e.g.* `./stable_diffusion/unet`. +In addition, a `model_index.json` file is created at the root of the local path, *e.g.* `./stable_diffusion/model_index.json` so that the complete pipeline can again be instantiated +from the local path. +- [`to`](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L118) which accepts a `string` or `torch.device` to move all models that are of type `torch.nn.Module` to the passed device. The behavior is fully analogous to [PyTorch's `to` method](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to). +- [`__call__`] method to use the pipeline in inference. `__call__` defines inference logic of the pipeline and should ideally encompass all aspects of it, from pre-processing to forwarding tensors to the different models and schedulers, as well as post-processing. The API of the `__call__` method can strongly vary from pipeline to pipeline. *E.g.* a text-to-image pipeline, such as [`StableDiffusionPipeline`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) should accept among other things the text prompt to generate the image. A pure image generation pipeline, such as [DDPMPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/ddpm) on the other hand can be run without providing any inputs. To better understand what inputs can be adapted for +each pipeline, one should look directly into the respective pipeline. + +**Note**: All pipelines have PyTorch's autograd disabled by decorating the `__call__` method with a [`torch.no_grad`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) decorator because pipelines should +not be used for training. If you want to store the gradients during the forward pass, we recommend writing your own pipeline, see also our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community) + +## Contribution + +We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire +all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**. + +- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file itself, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline. +- **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and +use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most +logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method. +- **Easy-to-tweak**: Certain pipelines will not be able to handle all use cases and tasks that you might like them to. If you want to use a certain pipeline for a specific use case that is not yet supported, you might have to copy the pipeline file and tweak the code to your needs. We try to make the pipeline code as readable as possible so that each part –from pre-processing to diffusing to post-processing– can easily be adapted. If you would like the community to benefit from your customized pipeline, we would love to see a contribution to our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community). If you feel that an important pipeline should be part of the official pipelines but isn't, a contribution to the [official pipelines](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines) would be even better. +- **One-purpose-only**: Pipelines should be used for one task and one task only. Even if two tasks are very similar from a modeling point of view, *e.g.* image2image translation and in-painting, pipelines shall be used for one task only to keep them *easy-to-tweak* and *readable*. + +## Examples + +### Text-to-Image generation with Stable Diffusion + +```python +# make sure you're logged in with `huggingface-cli login` +from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler + +pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") +pipe = pipe.to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +image = pipe(prompt).images[0] + +image.save("astronaut_rides_horse.png") +``` + +### Image-to-Image text-guided generation with Stable Diffusion + +The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images. + +```python +import requests +from PIL import Image +from io import BytesIO + +from diffusers import StableDiffusionImg2ImgPipeline + +# load the pipeline +device = "cuda" +pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", + torch_dtype=torch.float16, +).to(device) + +# let's download an initial image +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((768, 512)) + +prompt = "A fantasy landscape, trending on artstation" + +images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + +images[0].save("fantasy_landscape.png") +``` +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) + +### Tweak prompts reusing seeds and latents + +You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb). + + +### In-painting using Stable Diffusion + +The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by providing a mask and text prompt. + +```python +import PIL +import requests +import torch +from io import BytesIO + +from diffusers import StableDiffusionInpaintPipeline + +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") + +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + +init_image = download_image(img_url).resize((512, 512)) +mask_image = download_image(mask_url).resize((512, 512)) + +pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +prompt = "Face of a yellow cat, high resolution, sitting on a park bench" +image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] +``` + +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) diff --git a/diffusers/src/diffusers/pipelines/__init__.py b/diffusers/src/diffusers/pipelines/__init__.py new file mode 100755 index 0000000..38ca7fe --- /dev/null +++ b/diffusers/src/diffusers/pipelines/__init__.py @@ -0,0 +1,745 @@ +from typing import TYPE_CHECKING + +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_k_diffusion_available, + is_librosa_available, + is_note_seq_available, + is_onnx_available, + is_sentencepiece_available, + is_torch_available, + is_torch_npu_available, + is_transformers_available, +) + + +# These modules contain pipelines from multiple libraries/frameworks +_dummy_objects = {} +_import_structure = { + "controlnet": [], + "controlnet_hunyuandit": [], + "controlnet_sd3": [], + "controlnet_xs": [], + "deprecated": [], + "latent_diffusion": [], + "ledits_pp": [], + "marigold": [], + "pag": [], + "stable_diffusion": [], + "stable_diffusion_xl": [], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["auto_pipeline"] = [ + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + ] + _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] + _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] + _import_structure["ddim"] = ["DDIMPipeline"] + _import_structure["ddpm"] = ["DDPMPipeline"] + _import_structure["dit"] = ["DiTPipeline"] + _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"]) + _import_structure["pipeline_utils"] = [ + "AudioPipelineOutput", + "DiffusionPipeline", + "StableDiffusionMixin", + "ImagePipelineOutput", + ] + _import_structure["deprecated"].extend( + [ + "PNDMPipeline", + "LDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + "KarrasVePipeline", + ] + ) +try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_librosa_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) +else: + _import_structure["deprecated"].extend(["AudioDiffusionPipeline", "Mel"]) + +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) +else: + _import_structure["deprecated"].extend( + [ + "MidiProcessor", + "SpectrogramDiffusionPipeline", + ] + ) + +try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["deprecated"].extend( + [ + "VQDiffusionPipeline", + "AltDiffusionPipeline", + "AltDiffusionImg2ImgPipeline", + "CycleDiffusionPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionModelEditingPipeline", + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + ] + ) + _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"] + _import_structure["animatediff"] = [ + "AnimateDiffPipeline", + "AnimateDiffControlNetPipeline", + "AnimateDiffSDXLPipeline", + "AnimateDiffSparseControlNetPipeline", + "AnimateDiffVideoToVideoPipeline", + "AnimateDiffVideoToVideoControlNetPipeline", + ] + _import_structure["flux"] = [ + "FluxControlNetPipeline", + "FluxControlNetImg2ImgPipeline", + "FluxControlNetInpaintPipeline", + "FluxImg2ImgPipeline", + "FluxInpaintPipeline", + "FluxPipeline", + "FluxFillPipeline", + ] + _import_structure["audioldm"] = ["AudioLDMPipeline"] + _import_structure["audioldm2"] = [ + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + ] + _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] + _import_structure["cogvideo"] = [ + "CogVideoXPipeline", + "CogVideoXImageToVideoPipeline", + "CogVideoXVideoToVideoPipeline", + "CogVideoXInpaintPipeline", + "CogVideoXDualInpaintPipeline", + "CogVideoXBranchInpaintPipeline", + "CogVideoXSelfGuidanceInpaintPipeline", + "CogVideoXSFTInpaintPipeline", + "CogVideoXImageToVideoInpaintPipeline", + "CogVideoXI2VDualInpaintPipeline", + "CogVideoXI2VDualInpaintAnyLPipeline", + "CogVideoXI2VInpaintAnyLPipeline", + ] + _import_structure["controlnet"].extend( + [ + "BlipDiffusionControlNetPipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + ] + ) + _import_structure["pag"].extend( + [ + "AnimateDiffPAGPipeline", + "KolorsPAGPipeline", + "HunyuanDiTPAGPipeline", + "StableDiffusion3PAGPipeline", + "StableDiffusionPAGPipeline", + "StableDiffusionControlNetPAGPipeline", + "StableDiffusionXLPAGPipeline", + "StableDiffusionXLPAGInpaintPipeline", + "StableDiffusionXLControlNetPAGImg2ImgPipeline", + "StableDiffusionXLControlNetPAGPipeline", + "StableDiffusionXLPAGImg2ImgPipeline", + "PixArtSigmaPAGPipeline", + ] + ) + _import_structure["controlnet_xs"].extend( + [ + "StableDiffusionControlNetXSPipeline", + "StableDiffusionXLControlNetXSPipeline", + ] + ) + _import_structure["controlnet_hunyuandit"].extend( + [ + "HunyuanDiTControlNetPipeline", + ] + ) + _import_structure["controlnet_sd3"].extend( + [ + "StableDiffusion3ControlNetPipeline", + "StableDiffusion3ControlNetInpaintingPipeline", + ] + ) + _import_structure["deepfloyd_if"] = [ + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + ] + _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] + _import_structure["kandinsky"] = [ + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + ] + _import_structure["kandinsky2_2"] = [ + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + ] + _import_structure["kandinsky3"] = [ + "Kandinsky3Img2ImgPipeline", + "Kandinsky3Pipeline", + ] + _import_structure["latent_consistency_models"] = [ + "LatentConsistencyModelImg2ImgPipeline", + "LatentConsistencyModelPipeline", + ] + _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"]) + _import_structure["ledits_pp"].extend( + [ + "LEditsPPPipelineStableDiffusion", + "LEditsPPPipelineStableDiffusionXL", + ] + ) + _import_structure["latte"] = ["LattePipeline"] + _import_structure["lumina"] = ["LuminaText2ImgPipeline"] + _import_structure["marigold"].extend( + [ + "MarigoldDepthPipeline", + "MarigoldNormalsPipeline", + ] + ) + _import_structure["musicldm"] = ["MusicLDMPipeline"] + _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] + _import_structure["pia"] = ["PIAPipeline"] + _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"] + _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] + _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] + _import_structure["stable_audio"] = [ + "StableAudioProjectionModel", + "StableAudioPipeline", + ] + _import_structure["stable_cascade"] = [ + "StableCascadeCombinedPipeline", + "StableCascadeDecoderPipeline", + "StableCascadePriorPipeline", + ] + _import_structure["stable_diffusion"].extend( + [ + "CLIPImageProjection", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionPipeline", + "StableDiffusionUpscalePipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "StableDiffusionLDM3DPipeline", + ] + ) + _import_structure["aura_flow"] = ["AuraFlowPipeline"] + _import_structure["stable_diffusion_3"] = [ + "StableDiffusion3Pipeline", + "StableDiffusion3Img2ImgPipeline", + "StableDiffusion3InpaintPipeline", + ] + _import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] + _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] + _import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"] + _import_structure["stable_diffusion_gligen"] = [ + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + ] + _import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"] + _import_structure["stable_diffusion_xl"].extend( + [ + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + ] + ) + _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] + _import_structure["stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"] + _import_structure["stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"] + _import_structure["t2i_adapter"] = [ + "StableDiffusionAdapterPipeline", + "StableDiffusionXLAdapterPipeline", + ] + _import_structure["text_to_video_synthesis"] = [ + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "TextToVideoZeroSDXLPipeline", + "VideoToVideoSDPipeline", + ] + _import_structure["i2vgen_xl"] = ["I2VGenXLPipeline"] + _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"] + _import_structure["unidiffuser"] = [ + "ImageTextPipelineOutput", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + ] + _import_structure["wuerstchen"] = [ + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] +try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) +else: + _import_structure["onnx_utils"] = ["OnnxRuntimeModel"] +try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) +else: + _import_structure["stable_diffusion"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) + +try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import ( + dummy_torch_and_transformers_and_k_diffusion_objects, + ) + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) +else: + _import_structure["stable_diffusion_k_diffusion"] = [ + "StableDiffusionKDiffusionPipeline", + "StableDiffusionXLKDiffusionPipeline", + ] + +try: + if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import ( + dummy_torch_and_transformers_and_sentencepiece_objects, + ) + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects)) +else: + _import_structure["kolors"] = [ + "KolorsPipeline", + "KolorsImg2ImgPipeline", + ] + +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_flax_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_objects)) +else: + _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"] +try: + if not (is_flax_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"]) + _import_structure["stable_diffusion"].extend( + [ + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + ] + ) + _import_structure["stable_diffusion_xl"].extend( + [ + "FlaxStableDiffusionXLPipeline", + ] + ) + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 + + else: + from .auto_pipeline import ( + AutoPipelineForImage2Image, + AutoPipelineForInpainting, + AutoPipelineForText2Image, + ) + from .consistency_models import ConsistencyModelPipeline + from .dance_diffusion import DanceDiffusionPipeline + from .ddim import DDIMPipeline + from .ddpm import DDPMPipeline + from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline + from .dit import DiTPipeline + from .latent_diffusion import LDMSuperResolutionPipeline + from .pipeline_utils import ( + AudioPipelineOutput, + DiffusionPipeline, + ImagePipelineOutput, + StableDiffusionMixin, + ) + + try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_librosa_objects import * + else: + from .deprecated import AudioDiffusionPipeline, Mel + + try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_objects import * + else: + from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline + from .animatediff import ( + AnimateDiffControlNetPipeline, + AnimateDiffPipeline, + AnimateDiffSDXLPipeline, + AnimateDiffSparseControlNetPipeline, + AnimateDiffVideoToVideoControlNetPipeline, + AnimateDiffVideoToVideoPipeline, + ) + from .audioldm import AudioLDMPipeline + from .audioldm2 import ( + AudioLDM2Pipeline, + AudioLDM2ProjectionModel, + AudioLDM2UNet2DConditionModel, + ) + from .aura_flow import AuraFlowPipeline + from .blip_diffusion import BlipDiffusionPipeline + from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline, CogVideoXInpaintPipeline, CogVideoXDualInpaintPipeline, CogVideoXSelfGuidanceInpaintPipeline, CogVideoXSFTInpaintPipeline, CogVideoXImageToVideoInpaintPipeline, CogVideoXI2VDualInpaintPipeline, CogVideoXI2VDualInpaintAnyLPipeline, CogVideoXI2VInpaintAnyLPipeline + from .controlnet import ( + BlipDiffusionControlNetPipeline, + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPipeline, + StableDiffusionXLControlNetImg2ImgPipeline, + StableDiffusionXLControlNetInpaintPipeline, + StableDiffusionXLControlNetPipeline, + ) + from .controlnet_hunyuandit import ( + HunyuanDiTControlNetPipeline, + ) + from .controlnet_sd3 import StableDiffusion3ControlNetInpaintingPipeline, StableDiffusion3ControlNetPipeline + from .controlnet_xs import ( + StableDiffusionControlNetXSPipeline, + StableDiffusionXLControlNetXSPipeline, + ) + from .deepfloyd_if import ( + IFImg2ImgPipeline, + IFImg2ImgSuperResolutionPipeline, + IFInpaintingPipeline, + IFInpaintingSuperResolutionPipeline, + IFPipeline, + IFSuperResolutionPipeline, + ) + from .deprecated import ( + AltDiffusionImg2ImgPipeline, + AltDiffusionPipeline, + CycleDiffusionPipeline, + StableDiffusionInpaintPipelineLegacy, + StableDiffusionModelEditingPipeline, + StableDiffusionParadigmsPipeline, + StableDiffusionPix2PixZeroPipeline, + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + VQDiffusionPipeline, + ) + from .flux import ( + FluxControlNetImg2ImgPipeline, + FluxControlNetInpaintPipeline, + FluxControlNetPipeline, + FluxImg2ImgPipeline, + FluxInpaintPipeline, + FluxPipeline, + FluxFillPipeline, + ) + from .hunyuandit import HunyuanDiTPipeline + from .i2vgen_xl import I2VGenXLPipeline + from .kandinsky import ( + KandinskyCombinedPipeline, + KandinskyImg2ImgCombinedPipeline, + KandinskyImg2ImgPipeline, + KandinskyInpaintCombinedPipeline, + KandinskyInpaintPipeline, + KandinskyPipeline, + KandinskyPriorPipeline, + ) + from .kandinsky2_2 import ( + KandinskyV22CombinedPipeline, + KandinskyV22ControlnetImg2ImgPipeline, + KandinskyV22ControlnetPipeline, + KandinskyV22Img2ImgCombinedPipeline, + KandinskyV22Img2ImgPipeline, + KandinskyV22InpaintCombinedPipeline, + KandinskyV22InpaintPipeline, + KandinskyV22Pipeline, + KandinskyV22PriorEmb2EmbPipeline, + KandinskyV22PriorPipeline, + ) + from .kandinsky3 import ( + Kandinsky3Img2ImgPipeline, + Kandinsky3Pipeline, + ) + from .latent_consistency_models import ( + LatentConsistencyModelImg2ImgPipeline, + LatentConsistencyModelPipeline, + ) + from .latent_diffusion import LDMTextToImagePipeline + from .latte import LattePipeline + from .ledits_pp import ( + LEditsPPDiffusionPipelineOutput, + LEditsPPInversionPipelineOutput, + LEditsPPPipelineStableDiffusion, + LEditsPPPipelineStableDiffusionXL, + ) + from .lumina import LuminaText2ImgPipeline + from .marigold import ( + MarigoldDepthPipeline, + MarigoldNormalsPipeline, + ) + from .musicldm import MusicLDMPipeline + from .pag import ( + AnimateDiffPAGPipeline, + HunyuanDiTPAGPipeline, + KolorsPAGPipeline, + PixArtSigmaPAGPipeline, + StableDiffusion3PAGPipeline, + StableDiffusionControlNetPAGPipeline, + StableDiffusionPAGPipeline, + StableDiffusionXLControlNetPAGImg2ImgPipeline, + StableDiffusionXLControlNetPAGPipeline, + StableDiffusionXLPAGImg2ImgPipeline, + StableDiffusionXLPAGInpaintPipeline, + StableDiffusionXLPAGPipeline, + ) + from .paint_by_example import PaintByExamplePipeline + from .pia import PIAPipeline + from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline + from .semantic_stable_diffusion import SemanticStableDiffusionPipeline + from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline + from .stable_audio import StableAudioPipeline, StableAudioProjectionModel + from .stable_cascade import ( + StableCascadeCombinedPipeline, + StableCascadeDecoderPipeline, + StableCascadePriorPipeline, + ) + from .stable_diffusion import ( + CLIPImageProjection, + StableDiffusionDepth2ImgPipeline, + StableDiffusionImageVariationPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionInstructPix2PixPipeline, + StableDiffusionLatentUpscalePipeline, + StableDiffusionPipeline, + StableDiffusionUpscalePipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + from .stable_diffusion_3 import ( + StableDiffusion3Img2ImgPipeline, + StableDiffusion3InpaintPipeline, + StableDiffusion3Pipeline, + ) + from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline + from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline + from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline + from .stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline + from .stable_diffusion_panorama import StableDiffusionPanoramaPipeline + from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .stable_diffusion_sag import StableDiffusionSAGPipeline + from .stable_diffusion_xl import ( + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLPipeline, + ) + from .stable_video_diffusion import StableVideoDiffusionPipeline + from .t2i_adapter import ( + StableDiffusionAdapterPipeline, + StableDiffusionXLAdapterPipeline, + ) + from .text_to_video_synthesis import ( + TextToVideoSDPipeline, + TextToVideoZeroPipeline, + TextToVideoZeroSDXLPipeline, + VideoToVideoSDPipeline, + ) + from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline + from .unidiffuser import ( + ImageTextPipelineOutput, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, + ) + from .wuerstchen import ( + WuerstchenCombinedPipeline, + WuerstchenDecoderPipeline, + WuerstchenPriorPipeline, + ) + + try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_onnx_objects import * # noqa F403 + + else: + from .onnx_utils import OnnxRuntimeModel + + try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_onnx_objects import * + else: + from .stable_diffusion import ( + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionPipeline, + OnnxStableDiffusionUpscalePipeline, + StableDiffusionOnnxPipeline, + ) + + try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * + else: + from .stable_diffusion_k_diffusion import ( + StableDiffusionKDiffusionPipeline, + StableDiffusionXLKDiffusionPipeline, + ) + + try: + if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_sentencepiece_objects import * + else: + from .kolors import ( + KolorsImg2ImgPipeline, + KolorsPipeline, + ) + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_flax_objects import * # noqa F403 + else: + from .pipeline_flax_utils import FlaxDiffusionPipeline + + try: + if not (is_flax_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_flax_and_transformers_objects import * + else: + from .controlnet import FlaxStableDiffusionControlNetPipeline + from .stable_diffusion import ( + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + ) + from .stable_diffusion_xl import ( + FlaxStableDiffusionXLPipeline, + ) + + try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 + + else: + from .deprecated import ( + MidiProcessor, + SpectrogramDiffusionPipeline, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/amused/__init__.py b/diffusers/src/diffusers/pipelines/amused/__init__.py new file mode 100755 index 0000000..3c4d07a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/amused/__init__.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AmusedImg2ImgPipeline, + AmusedInpaintPipeline, + AmusedPipeline, + ) + + _dummy_objects.update( + { + "AmusedPipeline": AmusedPipeline, + "AmusedImg2ImgPipeline": AmusedImg2ImgPipeline, + "AmusedInpaintPipeline": AmusedInpaintPipeline, + } + ) +else: + _import_structure["pipeline_amused"] = ["AmusedPipeline"] + _import_structure["pipeline_amused_img2img"] = ["AmusedImg2ImgPipeline"] + _import_structure["pipeline_amused_inpaint"] = ["AmusedInpaintPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AmusedPipeline, + ) + else: + from .pipeline_amused import AmusedPipeline + from .pipeline_amused_img2img import AmusedImg2ImgPipeline + from .pipeline_amused_inpaint import AmusedInpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/amused/pipeline_amused.py b/diffusers/src/diffusers/pipelines/amused/pipeline_amused.py new file mode 100755 index 0000000..a8c24b0 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/amused/pipeline_amused.py @@ -0,0 +1,328 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedPipeline + + >>> pipe = AmusedPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class AmusedPipeline(DiffusionPipeline): + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[List[str], str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.IntTensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_encoder_hidden_states: Optional[torch.Tensor] = None, + output_type="pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: Tuple[int, int] = (0, 0), + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.IntTensor`, *optional*): + Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image + gneration. If not provided, the starting latents will be completely masked. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.Tensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.Tensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See + https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of + https://arxiv.org/abs/2307.01952. + micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of + https://arxiv.org/abs/2307.01952. + temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if height is None: + height = self.transformer.config.sample_size * self.vae_scale_factor + + if width is None: + width = self.transformer.config.sample_size * self.vae_scale_factor + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if latents is None: + latents = torch.full( + shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device + ) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + + num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, timestep in enumerate(self.scheduler.timesteps): + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if output_type == "latent": + output = latents + else: + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/diffusers/src/diffusers/pipelines/amused/pipeline_amused_img2img.py b/diffusers/src/diffusers/pipelines/amused/pipeline_amused_img2img.py new file mode 100755 index 0000000..c74275b --- /dev/null +++ b/diffusers/src/diffusers/pipelines/amused/pipeline_amused_img2img.py @@ -0,0 +1,349 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = AmusedImg2ImgPipeline.from_pretrained( + ... "amused/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "winter mountains" + >>> input_image = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg" + ... ) + ... .resize((512, 512)) + ... .convert("RGB") + ... ) + >>> image = pipe(prompt, input_image).images[0] + ``` +""" + + +class AmusedImg2ImgPipeline(DiffusionPipeline): + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before + # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter + # off the meta device. There should be a way to fix this instead of just not offloading it + _exclude_from_cpu_offload = ["vqvae"] + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[List[str], str]] = None, + image: PipelineImageInput = None, + strength: float = 0.5, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[torch.Generator] = None, + prompt_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_encoder_hidden_states: Optional[torch.Tensor] = None, + output_type="pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: Tuple[int, int] = (0, 0), + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.5): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 12): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.Tensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.Tensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See + https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of + https://arxiv.org/abs/2307.01952. + micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of + https://arxiv.org/abs/2307.01952. + temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + image = self.image_processor.preprocess(image) + + height, width = image.shape[-2:] + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + num_inference_steps = int(len(self.scheduler.timesteps) * strength) + start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps + + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents + latents_bsz, channels, latents_height, latents_width = latents.shape + latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width) + latents = self.scheduler.add_noise( + latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator + ) + latents = latents.repeat(num_images_per_prompt, 1, 1) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i in range(start_timestep_idx, len(self.scheduler.timesteps)): + timestep = self.scheduler.timesteps[i] + + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if output_type == "latent": + output = latents + else: + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/diffusers/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py b/diffusers/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py new file mode 100755 index 0000000..24801e0 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py @@ -0,0 +1,380 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = AmusedInpaintPipeline.from_pretrained( + ... "amused/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "fall mountains" + >>> input_image = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg" + ... ) + ... .resize((512, 512)) + ... .convert("RGB") + ... ) + >>> mask = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ... ) + ... .resize((512, 512)) + ... .convert("L") + ... ) + >>> pipe(prompt, input_image, mask).images[0].save("out.png") + ``` +""" + + +class AmusedInpaintPipeline(DiffusionPipeline): + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before + # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter + # off the meta device. There should be a way to fix this instead of just not offloading it + _exclude_from_cpu_offload = ["vqvae"] + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + do_resize=True, + ) + self.scheduler.register_to_config(masking_schedule="linear") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[List[str], str]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + strength: float = 1.0, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[torch.Generator] = None, + prompt_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_encoder_hidden_states: Optional[torch.Tensor] = None, + output_type="pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: Tuple[int, int] = (0, 0), + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.Tensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.Tensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See + https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of + https://arxiv.org/abs/2307.01952. + micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of + https://arxiv.org/abs/2307.01952. + temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + image = self.image_processor.preprocess(image) + + height, width = image.shape[-2:] + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + num_inference_steps = int(len(self.scheduler.timesteps) * strength) + start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps + + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents + latents_bsz, channels, latents_height, latents_width = latents.shape + latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width) + + mask = self.mask_processor.preprocess( + mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor + ) + mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device) + latents[mask] = self.scheduler.config.mask_token_id + + starting_mask_ratio = mask.sum() / latents.numel() + + latents = latents.repeat(num_images_per_prompt, 1, 1) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i in range(start_timestep_idx, len(self.scheduler.timesteps)): + timestep = self.scheduler.timesteps[i] + + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + starting_mask_ratio=starting_mask_ratio, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if output_type == "latent": + output = latents + else: + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/diffusers/src/diffusers/pipelines/animatediff/__init__.py b/diffusers/src/diffusers/pipelines/animatediff/__init__.py new file mode 100755 index 0000000..d916abf --- /dev/null +++ b/diffusers/src/diffusers/pipelines/animatediff/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {"pipeline_output": ["AnimateDiffPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"] + _import_structure["pipeline_animatediff_controlnet"] = ["AnimateDiffControlNetPipeline"] + _import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"] + _import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"] + _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"] + _import_structure["pipeline_animatediff_video2video_controlnet"] = ["AnimateDiffVideoToVideoControlNetPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .pipeline_animatediff import AnimateDiffPipeline + from .pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline + from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline + from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline + from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline + from .pipeline_animatediff_video2video_controlnet import AnimateDiffVideoToVideoControlNetPipeline + from .pipeline_output import AnimateDiffPipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff.py new file mode 100755 index 0000000..cb6f50f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -0,0 +1,860 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler + >>> from diffusers.utils import export_to_gif + + >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") + >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) + >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) + >>> output = pipe(prompt="A corgi walking in the park") + >>> frames = output.frames[0] + >>> export_to_gif(frames, "animation.gif") + ``` +""" + + +class AnimateDiffPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, +): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetMotionModel], + motion_adapter: MotionAdapter, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + num_frames: Optional[int] = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_chunk_size: int = 16, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + decode_chunk_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py new file mode 100755 index 0000000..5357d6d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -0,0 +1,1100 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor +from ..controlnet.multicontrolnet import MultiControlNetModel +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ( + ... AnimateDiffControlNetPipeline, + ... AutoencoderKL, + ... ControlNetModel, + ... MotionAdapter, + ... LCMScheduler, + ... ) + >>> from diffusers.utils import export_to_gif, load_video + + >>> # Additionally, you will need a preprocess videos before they can be used with the ControlNet + >>> # HF maintains just the right package for it: `pip install controlnet_aux` + >>> from controlnet_aux.processor import ZoeDetector + + >>> # Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file + >>> # Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained() + >>> controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16) + + >>> # We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3) + >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM") + + >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) + >>> pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained( + ... "SG161222/Realistic_Vision_V5.1_noVAE", + ... motion_adapter=motion_adapter, + ... controlnet=controlnet, + ... vae=vae, + ... ).to(device="cuda", dtype=torch.float16) + >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") + >>> pipe.load_lora_weights( + ... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora" + ... ) + >>> pipe.set_adapters(["lcm-lora"], [0.8]) + + >>> depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda") + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif" + ... ) + >>> conditioning_frames = [] + + >>> with pipe.progress_bar(total=len(video)) as progress_bar: + ... for frame in video: + ... conditioning_frames.append(depth_detector(frame)) + ... progress_bar.update() + + >>> prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality" + >>> negative_prompt = "bad quality, worst quality" + + >>> video = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=len(video), + ... num_inference_steps=10, + ... guidance_scale=2.0, + ... conditioning_frames=conditioning_frames, + ... generator=torch.Generator().manual_seed(42), + ... ).frames[0] + + >>> export_to_gif(video, "animatediff_controlnet.gif", fps=8) + ``` +""" + + +class AnimateDiffControlNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, +): + r""" + Pipeline for text-to-video generation with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetMotionModel], + motion_adapter: MotionAdapter, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + feature_extractor: Optional[CLIPImageProcessor] = None, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_video_processor = VideoProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + num_frames, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + video=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(video, list): + raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(video)}") + if len(video) != num_frames: + raise ValueError(f"Excepted image to have length {num_frames} but got {len(video)=}") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(video, list) or not isinstance(video[0], list): + raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(video)=}") + if len(video[0]) != num_frames: + raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(video[0])=}") + if any(len(img) != len(video[0]) for img in video): + raise ValueError("All conditioning frame batches for multicontrolnet must be same size") + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_video( + self, + video, + width, + height, + batch_size, + num_videos_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + video = self.control_video_processor.preprocess_video(video, height=height, width=width).to( + dtype=torch.float32 + ) + video = video.permute(0, 2, 1, 3, 4).flatten(0, 1) + video_batch_size = video.shape[0] + + if video_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_videos_per_prompt + + video = video.repeat_interleave(repeat_by, dim=0) + video = video.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + video = torch.cat([video] * 2) + + return video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_frames: Optional[int] = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[PipelineImageInput] = None, + conditioning_frames: Optional[List[PipelineImageInput]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_chunk_size: int = 16, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + conditioning_frames (`List[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If multiple + ControlNets are specified, images must be passed as a list such that each element of the list can be + correctly batched for input to a single ControlNet. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + num_frames=num_frames, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + video=conditioning_frames, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + if isinstance(controlnet, ControlNetModel): + conditioning_frames = self.prepare_video( + video=conditioning_frames, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + cond_prepared_videos = [] + for frame_ in conditioning_frames: + prepared_video = self.prepare_video( + video=frame_, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + cond_prepared_videos.append(prepared_video) + conditioning_frames = cond_prepared_videos + else: + assert False + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + control_model_input = torch.transpose(control_model_input, 1, 2) + control_model_input = control_model_input.reshape( + (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4]) + ) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=conditioning_frames, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py new file mode 100755 index 0000000..e531c91 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -0,0 +1,1276 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import PipelineImageInput +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, MotionAdapter, UNet2DConditionModel, UNetMotionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + FusedAttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import ( + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.models import MotionAdapter + >>> from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler + >>> from diffusers.utils import export_to_gif + + >>> adapter = MotionAdapter.from_pretrained( + ... "a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16 + ... ) + + >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0" + >>> scheduler = DDIMScheduler.from_pretrained( + ... model_id, + ... subfolder="scheduler", + ... clip_sample=False, + ... timestep_spacing="linspace", + ... beta_schedule="linear", + ... steps_offset=1, + ... ) + >>> pipe = AnimateDiffSDXLPipeline.from_pretrained( + ... model_id, + ... motion_adapter=adapter, + ... scheduler=scheduler, + ... torch_dtype=torch.float16, + ... variant="fp16", + ... ).to("cuda") + + >>> # enable memory savings + >>> pipe.enable_vae_slicing() + >>> pipe.enable_vae_tiling() + + >>> output = pipe( + ... prompt="a panda surfing in the ocean, realistic, high quality", + ... negative_prompt="low quality, worst quality", + ... num_inference_steps=20, + ... guidance_scale=8, + ... width=1024, + ... height=1024, + ... num_frames=16, + ... ) + + >>> frames = output.frames[0] + >>> export_to_gif(frames, "animation.gif") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimateDiffSDXLPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + FreeInitMixin, +): + r""" + Pipeline for text-to-video generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetMotionModel], + motion_adapter: MotionAdapter, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view( + bs_embed * num_videos_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view( + bs_embed * num_videos_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + num_frames: int = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + num_frames: + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated video. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated video. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower video quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the + `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.AnimateDiffPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 7.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 8. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_videos_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + + # 9. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + progress_bar.update() + + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # 10. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # 11. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py new file mode 100755 index 0000000..8b037cd --- /dev/null +++ b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -0,0 +1,1010 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.controlnet_sparsectrl import SparseControlNetModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import AnimateDiffSparseControlNetPipeline + >>> from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel + >>> from diffusers.schedulers import DPMSolverMultistepScheduler + >>> from diffusers.utils import export_to_gif, load_image + + >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" + >>> motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3" + >>> controlnet_id = "guoyww/animatediff-sparsectrl-scribble" + >>> lora_adapter_id = "guoyww/animatediff-motion-lora-v1-5-3" + >>> vae_id = "stabilityai/sd-vae-ft-mse" + >>> device = "cuda" + + >>> motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device) + >>> controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device) + >>> vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device) + >>> scheduler = DPMSolverMultistepScheduler.from_pretrained( + ... model_id, + ... subfolder="scheduler", + ... beta_schedule="linear", + ... algorithm_type="dpmsolver++", + ... use_karras_sigmas=True, + ... ) + >>> pipe = AnimateDiffSparseControlNetPipeline.from_pretrained( + ... model_id, + ... motion_adapter=motion_adapter, + ... controlnet=controlnet, + ... vae=vae, + ... scheduler=scheduler, + ... torch_dtype=torch.float16, + ... ).to(device) + >>> pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora") + >>> pipe.fuse_lora(lora_scale=1.0) + + >>> prompt = "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality" + >>> negative_prompt = "low quality, worst quality, letterboxed" + + >>> image_files = [ + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png", + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png", + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png", + ... ] + >>> condition_frame_indices = [0, 8, 15] + >>> conditioning_frames = [load_image(img_file) for img_file in image_files] + + >>> video = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=25, + ... conditioning_frames=conditioning_frames, + ... controlnet_conditioning_scale=1.0, + ... controlnet_frame_indices=condition_frame_indices, + ... generator=torch.Generator().manual_seed(1337), + ... ).frames[0] + >>> export_to_gif(video, "output.gif") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class AnimateDiffSparseControlNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, +): + r""" + Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls + to Text-to-Video Diffusion Models](https://arxiv.org/abs/2311.16933). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetMotionModel], + motion_adapter: MotionAdapter, + controlnet: SparseControlNetModel, + scheduler: KarrasDiffusionSchedulers, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + image=None, + controlnet_conditioning_scale: float = 1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + + # check `image` + if ( + isinstance(self.controlnet, SparseControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, SparseControlNetModel) + ): + if isinstance(image, list): + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, SparseControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, SparseControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image(self, image, width, height, device, dtype): + image = self.control_image_processor.preprocess(image, height=height, width=width) + controlnet_images = image.unsqueeze(0).to(device, dtype) + batch_size, num_frames, channels, height, width = controlnet_images.shape + + # TODO: remove below line + assert controlnet_images.min() >= 0 and controlnet_images.max() <= 1 + + if self.controlnet.use_simplified_condition_embedding: + controlnet_images = controlnet_images.reshape(batch_size * num_frames, channels, height, width) + controlnet_images = 2 * controlnet_images - 1 + conditioning_frames = retrieve_latents(self.vae.encode(controlnet_images)) * self.vae.config.scaling_factor + conditioning_frames = conditioning_frames.reshape( + batch_size, num_frames, 4, height // self.vae_scale_factor, width // self.vae_scale_factor + ) + else: + conditioning_frames = controlnet_images + + conditioning_frames = conditioning_frames.permute(0, 2, 1, 3, 4) # [b, c, f, h, w] + return conditioning_frames + + def prepare_sparse_control_conditioning( + self, + conditioning_frames: torch.Tensor, + num_frames: int, + controlnet_frame_indices: int, + device: torch.device, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert conditioning_frames.shape[2] >= len(controlnet_frame_indices) + + batch_size, channels, _, height, width = conditioning_frames.shape + controlnet_cond = torch.zeros((batch_size, channels, num_frames, height, width), dtype=dtype, device=device) + controlnet_cond_mask = torch.zeros((batch_size, 1, num_frames, height, width), dtype=dtype, device=device) + controlnet_cond[:, :, controlnet_frame_indices] = conditioning_frames[:, :, : len(controlnet_frame_indices)] + controlnet_cond_mask[:, :, controlnet_frame_indices] = 1 + + return controlnet_cond, controlnet_cond_mask + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + conditioning_frames: Optional[List[PipelineImageInput]] = None, + output_type: str = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + controlnet_frame_indices: List[int] = [0], + guess_mode: bool = False, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + conditioning_frames (`List[PipelineImageInput]`, *optional*): + The SparseControlNet input to provide guidance to the `unet` for generation. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + controlnet_frame_indices (`List[int]`): + The indices where the conditioning frames must be applied for generation. Multiple frames can be + provided to guide the model to generate similar structure outputs, where the `unet` can + "fill-in-the-gaps" for interpolation videos, or a single frame could be provided for general expected + structure. Must have the same length as `conditioning_frames`. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + image=conditioning_frames, + controlnet_conditioning_scale=controlnet_conditioning_scale, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, SparseControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + # 4. Prepare IP-Adapter embeddings + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 5. Prepare controlnet conditioning + conditioning_frames = self.prepare_image(conditioning_frames, width, height, device, controlnet.dtype) + controlnet_cond, controlnet_cond_mask = self.prepare_sparse_control_conditioning( + conditioning_frames, num_frames, controlnet_frame_indices, device, controlnet.dtype + ) + + # 6. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 7. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 10. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if guess_mode and self.do_classifier_free_guidance: + # Infer SparseControlNetModel only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=controlnet_cond, + conditioning_mask=controlnet_cond_mask, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 11. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 12. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py new file mode 100755 index 0000000..1ebe2b9 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -0,0 +1,1059 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import imageio + >>> import requests + >>> import torch + >>> from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter + >>> from diffusers.utils import export_to_gif + >>> from io import BytesIO + >>> from PIL import Image + + >>> adapter = MotionAdapter.from_pretrained( + ... "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16 + ... ) + >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained( + ... "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter + ... ).to("cuda") + >>> pipe.scheduler = DDIMScheduler( + ... beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace" + ... ) + + + >>> def load_video(file_path: str): + ... images = [] + + ... if file_path.startswith(("http://", "https://")): + ... # If the file_path is a URL + ... response = requests.get(file_path) + ... response.raise_for_status() + ... content = BytesIO(response.content) + ... vid = imageio.get_reader(content) + ... else: + ... # Assuming it's a local file path + ... vid = imageio.get_reader(file_path) + + ... for frame in vid: + ... pil_image = Image.fromarray(frame) + ... images.append(pil_image) + + ... return images + + + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif" + ... ) + >>> output = pipe( + ... video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5 + ... ) + >>> frames = output.frames[0] + >>> export_to_gif(frames, "animation.gif") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimateDiffVideoToVideoPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, +): + r""" + Pipeline for video-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + motion_adapter: MotionAdapter, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor: + latents = [] + for i in range(0, len(video), decode_chunk_size): + batch_video = video[i : i + decode_chunk_size] + batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator) + latents.append(batch_video) + return torch.cat(latents) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + height, + width, + video=None, + latents=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, + video: Optional[torch.Tensor] = None, + height: int = 64, + width: int = 64, + num_channels_latents: int = 4, + batch_size: int = 1, + timestep: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + decode_chunk_size: int = 16, + add_noise: bool = False, + ) -> torch.Tensor: + num_frames = video.shape[1] if latents is None else latents.shape[2] + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + video = video.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0) + for i in range(batch_size) + ] + else: + init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video] + + init_latents = torch.cat(init_latents, dim=0) + + # restore vae to original dtype + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + error_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Please make sure to update your script to pass as many initial images as text prompts" + ) + raise ValueError(error_message) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4) + else: + if shape != latents.shape: + # [B, C, F, H, W] + raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}") + + latents = latents.to(device, dtype=dtype) + + if add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(latents, noise, timestep) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + video: List[List[PipelineImageInput]] = None, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + enforce_inference_steps: bool = False, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 7.5, + strength: float = 0.8, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_chunk_size: int = 16, + ): + r""" + The call function to the pipeline for generation. + + Args: + video (`List[PipelineImageInput]`): + The input video to condition the generation on. Must be a list of images/frames of the video. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + strength (`float`, *optional*, defaults to 0.8): + Higher strength leads to more differences between original video and generated video. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + decode_chunk_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. + + Examples: + + Returns: + [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + strength=strength, + height=height, + width=width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + video=video, + latents=latents, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.dtype + + # 3. Prepare timesteps + if not enforce_inference_steps: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + else: + denoising_inference_steps = int(num_inference_steps / strength) + timesteps, denoising_inference_steps = retrieve_timesteps( + self.scheduler, denoising_inference_steps, device, timesteps, sigmas + ) + timesteps = timesteps[-num_inference_steps:] + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + + # 4. Prepare latent variables + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + # Move the number of frames before the number of channels. + video = video.permute(0, 2, 1, 3, 4) + video = video.to(device=device, dtype=dtype) + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + video=video, + height=height, + width=width, + num_channels_latents=num_channels_latents, + batch_size=batch_size * num_videos_per_prompt, + timestep=latent_timestep, + dtype=dtype, + device=device, + generator=generator, + latents=latents, + decode_chunk_size=decode_chunk_size, + add_noise=enforce_inference_steps, + ) + + # 5. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + num_frames = latents.shape[2] + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + # 6. Prepare IP-Adapter embeddings + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + num_inference_steps = len(timesteps) + # make sure to readjust timesteps based on strength + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 9. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 10. Post-processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 11. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py new file mode 100755 index 0000000..1d26f95 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py @@ -0,0 +1,1341 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor +from ..controlnet.multicontrolnet import MultiControlNetModel +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from tqdm.auto import tqdm + + >>> from diffusers import AnimateDiffVideoToVideoControlNetPipeline + >>> from diffusers.utils import export_to_gif, load_video + >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16 + ... ) + >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM") + >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) + + >>> pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained( + ... "SG161222/Realistic_Vision_V5.1_noVAE", + ... motion_adapter=motion_adapter, + ... controlnet=controlnet, + ... vae=vae, + ... ).to(device="cuda", dtype=torch.float16) + + >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") + >>> pipe.load_lora_weights( + ... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora" + ... ) + >>> pipe.set_adapters(["lcm-lora"], [0.8]) + + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif" + ... ) + >>> video = [frame.convert("RGB") for frame in video] + + >>> from controlnet_aux.processor import OpenposeDetector + + >>> open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cuda") + >>> for frame in tqdm(video): + ... conditioning_frames.append(open_pose(frame)) + + >>> prompt = "astronaut in space, dancing" + >>> negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly" + + >>> strength = 0.8 + >>> with torch.inference_mode(): + ... video = pipe( + ... video=video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=10, + ... guidance_scale=2.0, + ... controlnet_conditioning_scale=0.75, + ... conditioning_frames=conditioning_frames, + ... strength=strength, + ... generator=torch.Generator().manual_seed(42), + ... ).frames[0] + + >>> video = [frame.resize(conditioning_frames[0].size) for frame in video] + >>> export_to_gif(video, f"animatediff_vid2vid_controlnet.gif", fps=8) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimateDiffVideoToVideoControlNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, +): + r""" + Pipeline for video-to-video generation with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]` or `Tuple[ControlNetModel]` or `MultiControlNetModel`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + motion_adapter: MotionAdapter, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_video_processor = VideoProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_video + def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor: + latents = [] + for i in range(0, len(video), decode_chunk_size): + batch_video = video[i : i + decode_chunk_size] + batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator) + latents.append(batch_video) + return torch.cat(latents) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + height, + width, + video=None, + conditioning_frames=None, + latents=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + + num_frames = len(video) if latents is None else latents.shape[2] + + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(conditioning_frames, list): + raise TypeError( + f"For single controlnet, `image` must be of type `list` but got {type(conditioning_frames)}" + ) + if len(conditioning_frames) != num_frames: + raise ValueError(f"Excepted image to have length {num_frames} but got {len(conditioning_frames)=}") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(conditioning_frames, list) or not isinstance(conditioning_frames[0], list): + raise TypeError( + f"For multiple controlnets: `image` must be type list of lists but got {type(conditioning_frames)=}" + ) + if len(conditioning_frames[0]) != num_frames: + raise ValueError( + f"Expected length of image sublist as {num_frames} but got {len(conditioning_frames)=}" + ) + if any(len(img) != len(conditioning_frames[0]) for img in conditioning_frames): + raise ValueError("All conditioning frame batches for multicontrolnet must be same size") + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.prepare_latents + def prepare_latents( + self, + video: Optional[torch.Tensor] = None, + height: int = 64, + width: int = 64, + num_channels_latents: int = 4, + batch_size: int = 1, + timestep: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + decode_chunk_size: int = 16, + add_noise: bool = False, + ) -> torch.Tensor: + num_frames = video.shape[1] if latents is None else latents.shape[2] + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + video = video.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0) + for i in range(batch_size) + ] + else: + init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video] + + init_latents = torch.cat(init_latents, dim=0) + + # restore vae to original dtype + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + error_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Please make sure to update your script to pass as many initial images as text prompts" + ) + raise ValueError(error_message) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4) + else: + if shape != latents.shape: + # [B, C, F, H, W] + raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}") + + latents = latents.to(device, dtype=dtype) + + if add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(latents, noise, timestep) + + return latents + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_controlnet.AnimateDiffControlNetPipeline.prepare_video + def prepare_conditioning_frames( + self, + video, + width, + height, + batch_size, + num_videos_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + video = self.control_video_processor.preprocess_video(video, height=height, width=width).to( + dtype=torch.float32 + ) + video = video.permute(0, 2, 1, 3, 4).flatten(0, 1) + video_batch_size = video.shape[0] + + if video_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_videos_per_prompt + + video = video.repeat_interleave(repeat_by, dim=0) + video = video.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + video = torch.cat([video] * 2) + + return video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + video: List[List[PipelineImageInput]] = None, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + enforce_inference_steps: bool = False, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 7.5, + strength: float = 0.8, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + conditioning_frames: Optional[List[PipelineImageInput]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_chunk_size: int = 16, + ): + r""" + The call function to the pipeline for generation. + + Args: + video (`List[PipelineImageInput]`): + The input video to condition the generation on. Must be a list of images/frames of the video. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + strength (`float`, *optional*, defaults to 0.8): + Higher strength leads to more differences between original video and generated video. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + conditioning_frames (`List[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If multiple + ControlNets are specified, images must be passed as a list such that each element of the list can be + correctly batched for input to a single ControlNet. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + decode_chunk_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. + + Examples: + + Returns: + [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + strength=strength, + height=height, + width=width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + video=video, + conditioning_frames=conditioning_frames, + latents=latents, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.dtype + + # 3. Prepare timesteps + if not enforce_inference_steps: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + else: + denoising_inference_steps = int(num_inference_steps / strength) + timesteps, denoising_inference_steps = retrieve_timesteps( + self.scheduler, denoising_inference_steps, device, timesteps, sigmas + ) + timesteps = timesteps[-num_inference_steps:] + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + + # 4. Prepare latent variables + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + # Move the number of frames before the number of channels. + video = video.permute(0, 2, 1, 3, 4) + video = video.to(device=device, dtype=dtype) + + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + video=video, + height=height, + width=width, + num_channels_latents=num_channels_latents, + batch_size=batch_size * num_videos_per_prompt, + timestep=latent_timestep, + dtype=dtype, + device=device, + generator=generator, + latents=latents, + decode_chunk_size=decode_chunk_size, + add_noise=enforce_inference_steps, + ) + + # 5. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + num_frames = latents.shape[2] + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + # 6. Prepare IP-Adapter embeddings + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 7. Prepare ControlNet conditions + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + if isinstance(controlnet, ControlNetModel): + conditioning_frames = self.prepare_conditioning_frames( + video=conditioning_frames, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + cond_prepared_videos = [] + for frame_ in conditioning_frames: + prepared_video = self.prepare_conditioning_frames( + video=frame_, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + cond_prepared_videos.append(prepared_video) + conditioning_frames = cond_prepared_videos + else: + assert False + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + num_inference_steps = len(timesteps) + # make sure to readjust timesteps based on strength + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 10. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + control_model_input = torch.transpose(control_model_input, 1, 2) + control_model_input = control_model_input.reshape( + (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4]) + ) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=conditioning_frames, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 11. Post-processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 12. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/animatediff/pipeline_output.py b/diffusers/src/diffusers/pipelines/animatediff/pipeline_output.py new file mode 100755 index 0000000..2417223 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/animatediff/pipeline_output.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image +import torch + +from ...utils import BaseOutput + + +@dataclass +class AnimateDiffPipelineOutput(BaseOutput): + r""" + Output class for AnimateDiff pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised + PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)` + """ + + frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]] diff --git a/diffusers/src/diffusers/pipelines/audioldm/__init__.py b/diffusers/src/diffusers/pipelines/audioldm/__init__.py new file mode 100755 index 0000000..a002b4a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/audioldm/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AudioLDMPipeline, + ) + + _dummy_objects.update({"AudioLDMPipeline": AudioLDMPipeline}) +else: + _import_structure["pipeline_audioldm"] = ["AudioLDMPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AudioLDMPipeline, + ) + + else: + from .pipeline_audioldm import AudioLDMPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/audioldm/pipeline_audioldm.py b/diffusers/src/diffusers/pipelines/audioldm/pipeline_audioldm.py new file mode 100755 index 0000000..105ca40 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/audioldm/pipeline_audioldm.py @@ -0,0 +1,546 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import AudioLDMPipeline + >>> import torch + >>> import scipy + + >>> repo_id = "cvssp/audioldm-s-full-v2" + >>> pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs" + >>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0] + + >>> # save the audio sample as a .wav file + >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio) + ``` +""" + + +class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for text-to-audio generation using AudioLDM. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.ClapTextModelWithProjection`]): + Frozen text-encoder (`ClapTextModelWithProjection`, specifically the + [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. + tokenizer ([`PreTrainedTokenizer`]): + A [`~transformers.RobertaTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded audio latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + vocoder ([`~transformers.SpeechT5HifiGan`]): + Vocoder of class `SpeechT5HifiGan`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ClapTextModelWithProjection, + tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast], + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vocoder: SpeechT5HifiGan, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def _encode_prompt( + self, + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device (`torch.device`): + torch device + num_waveforms_per_prompt (`int`): + number of waveforms that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the audio generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLAP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + ) + prompt_embeds = prompt_embeds.text_embeds + # additional L_2 normalization over each hidden-state + prompt_embeds = F.normalize(prompt_embeds, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + ( + bs_embed, + seq_len, + ) = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds.text_embeds + # additional L_2 normalization over each hidden-state + negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + mel_spectrogram = self.vae.decode(latents).sample + return mel_spectrogram + + def mel_spectrogram_to_waveform(self, mel_spectrogram): + if mel_spectrogram.dim() == 4: + mel_spectrogram = mel_spectrogram.squeeze(1) + + waveform = self.vocoder(mel_spectrogram) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + waveform = waveform.cpu().float() + return waveform + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor + if audio_length_in_s < min_audio_length_in_s: + raise ValueError( + f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but " + f"is {audio_length_in_s}." + ) + + if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0: + raise ValueError( + f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the " + f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of " + f"{self.vae_scale_factor}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim + def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(self.vocoder.config.model_in_dim) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + audio_length_in_s: Optional[float] = None, + num_inference_steps: int = 10, + guidance_scale: float = 2.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_waveforms_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + output_type: Optional[str] = "np", + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + audio_length_in_s (`int`, *optional*, defaults to 5.12): + The length of the generated audio sample in seconds. + num_inference_steps (`int`, *optional*, defaults to 10): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 2.5): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. + + Examples: + + Returns: + [`~pipelines.AudioPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to spectrogram height + vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor + + height = int(audio_length_in_s / vocoder_upsample_factor) + + original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) + if height % self.vae_scale_factor != 0: + height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor + logger.info( + f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} " + f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the " + f"denoising process." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_latents, + height, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=None, + class_labels=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 8. Post-processing + mel_spectrogram = self.decode_latents(latents) + + audio = self.mel_spectrogram_to_waveform(mel_spectrogram) + + audio = audio[:, :original_waveform_length] + + if output_type == "np": + audio = audio.numpy() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/diffusers/src/diffusers/pipelines/audioldm2/__init__.py b/diffusers/src/diffusers/pipelines/audioldm2/__init__.py new file mode 100755 index 0000000..23cd0e4 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/audioldm2/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_audioldm2"] = ["AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel"] + _import_structure["pipeline_audioldm2"] = ["AudioLDM2Pipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel + from .pipeline_audioldm2 import AudioLDM2Pipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/diffusers/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py new file mode 100755 index 0000000..2af3078 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -0,0 +1,1530 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...models.activations import get_activation +from ...models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from ...models.embeddings import ( + TimestepEmbedding, + Timesteps, +) +from ...models.modeling_utils import ModelMixin +from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ...models.transformers.transformer_2d import Transformer2DModel +from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D +from ...models.unets.unet_2d_condition import UNet2DConditionOutput +from ...utils import BaseOutput, is_torch_version, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def add_special_tokens(hidden_states, attention_mask, sos_token, eos_token): + batch_size = hidden_states.shape[0] + + if attention_mask is not None: + # Add two more steps to attn mask + new_attn_mask_step = attention_mask.new_ones((batch_size, 1)) + attention_mask = torch.concat([new_attn_mask_step, attention_mask, new_attn_mask_step], dim=-1) + + # Add the SOS / EOS tokens at the start / end of the sequence respectively + sos_token = sos_token.expand(batch_size, 1, -1) + eos_token = eos_token.expand(batch_size, 1, -1) + hidden_states = torch.concat([sos_token, hidden_states, eos_token], dim=1) + return hidden_states, attention_mask + + +@dataclass +class AudioLDM2ProjectionModelOutput(BaseOutput): + """ + Args: + Class for AudioLDM2 projection layer's outputs. + hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states obtained by linearly projecting the hidden-states for each of the text + encoders and subsequently concatenating them together. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices, formed by concatenating the attention masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + """ + + hidden_states: torch.Tensor + attention_mask: Optional[torch.LongTensor] = None + + +class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin): + """ + A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned + embedding vectors at the start and end of each text embedding sequence respectively. Each variable appended with + `_1` refers to that corresponding to the second text encoder. Otherwise, it is from the first. + + Args: + text_encoder_dim (`int`): + Dimensionality of the text embeddings from the first text encoder (CLAP). + text_encoder_1_dim (`int`): + Dimensionality of the text embeddings from the second text encoder (T5 or VITS). + langauge_model_dim (`int`): + Dimensionality of the text embeddings from the language model (GPT2). + """ + + @register_to_config + def __init__( + self, + text_encoder_dim, + text_encoder_1_dim, + langauge_model_dim, + use_learned_position_embedding=None, + max_seq_length=None, + ): + super().__init__() + # additional projection layers for each text encoder + self.projection = nn.Linear(text_encoder_dim, langauge_model_dim) + self.projection_1 = nn.Linear(text_encoder_1_dim, langauge_model_dim) + + # learnable SOS / EOS token embeddings for each text encoder + self.sos_embed = nn.Parameter(torch.ones(langauge_model_dim)) + self.eos_embed = nn.Parameter(torch.ones(langauge_model_dim)) + + self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim)) + self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim)) + + self.use_learned_position_embedding = use_learned_position_embedding + + # learable positional embedding for vits encoder + if self.use_learned_position_embedding is not None: + self.learnable_positional_embedding = torch.nn.Parameter( + torch.zeros((1, text_encoder_1_dim, max_seq_length)) + ) + + def forward( + self, + hidden_states: Optional[torch.Tensor] = None, + hidden_states_1: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + attention_mask_1: Optional[torch.LongTensor] = None, + ): + hidden_states = self.projection(hidden_states) + hidden_states, attention_mask = add_special_tokens( + hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed + ) + + # Add positional embedding for Vits hidden state + if self.use_learned_position_embedding is not None: + hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + self.learnable_positional_embedding).permute(0, 2, 1) + + hidden_states_1 = self.projection_1(hidden_states_1) + hidden_states_1, attention_mask_1 = add_special_tokens( + hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1 + ) + + # concatenate clap and t5 text encoding + hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1) + + # concatenate attention masks + if attention_mask is None and attention_mask_1 is not None: + attention_mask = attention_mask_1.new_ones((hidden_states[:2])) + elif attention_mask is not None and attention_mask_1 is None: + attention_mask_1 = attention_mask.new_ones((hidden_states_1[:2])) + + if attention_mask is not None and attention_mask_1 is not None: + attention_mask = torch.cat([attention_mask, attention_mask_1], dim=-1) + else: + attention_mask = None + + return AudioLDM2ProjectionModelOutput( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + + +class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional + self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up + to two cross-attention embeddings, `encoder_hidden_states` and `encoder_hidden_states_1`. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can only be `UNetMidBlock2DCrossAttn` for AudioLDM2. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention (`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use `positional`.") + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + else: + raise ValueError( + f"unknown mid_block_type : {mid_block_type}. Should be `UNetMidBlock2DCrossAttn` for AudioLDM2." + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + encoder_hidden_states_1: Optional[torch.Tensor] = None, + encoder_attention_mask_1: Optional[torch.Tensor] = None, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`AudioLDM2UNet2DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + encoder_hidden_states_1 (`torch.Tensor`, *optional*): + A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be + used to condition the model on a different set of embeddings to `encoder_hidden_states`. + encoder_attention_mask_1 (`torch.Tensor`, *optional*): + A cross-attention mask of shape `(batch, sequence_length_2)` is applied to `encoder_hidden_states_1`. + If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if encoder_attention_mask_1 is not None: + encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0 + encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) + if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: + raise ValueError( + "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " + f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" + ) + self.cross_attention_dim = cross_attention_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + for j in range(len(cross_attention_dim)): + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim[j], + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + double_self_attention=True if cross_attention_dim[j] is None else False, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states_1: Optional[torch.Tensor] = None, + encoder_attention_mask_1: Optional[torch.Tensor] = None, + ): + output_states = () + num_layers = len(self.resnets) + num_attention_per_layer = len(self.attentions) // num_layers + + encoder_hidden_states_1 = ( + encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states + ) + encoder_attention_mask_1 = ( + encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask + ) + + for i in range(num_layers): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.resnets[i]), + hidden_states, + temb, + **ckpt_kwargs, + ) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False), + hidden_states, + forward_encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + forward_encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = self.resnets[i](hidden_states, temb) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self.attentions[i * num_attention_per_layer + idx]( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=forward_encoder_hidden_states, + encoder_attention_mask=forward_encoder_attention_mask, + return_dict=False, + )[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) + if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: + raise ValueError( + "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " + f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" + ) + self.cross_attention_dim = cross_attention_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + for j in range(len(cross_attention_dim)): + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim[j], + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + double_self_attention=True if cross_attention_dim[j] is None else False, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states_1: Optional[torch.Tensor] = None, + encoder_attention_mask_1: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + num_attention_per_layer = len(self.attentions) // (len(self.resnets) - 1) + + encoder_hidden_states_1 = ( + encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states + ) + encoder_attention_mask_1 = ( + encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask + ) + + for i in range(len(self.resnets[1:])): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False), + hidden_states, + forward_encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + forward_encoder_attention_mask, + **ckpt_kwargs, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.resnets[i + 1]), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self.attentions[i * num_attention_per_layer + idx]( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=forward_encoder_hidden_states, + encoder_attention_mask=forward_encoder_attention_mask, + return_dict=False, + )[0] + + hidden_states = self.resnets[i + 1](hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) + if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: + raise ValueError( + "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " + f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" + ) + self.cross_attention_dim = cross_attention_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + for j in range(len(cross_attention_dim)): + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim[j], + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + double_self_attention=True if cross_attention_dim[j] is None else False, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states_1: Optional[torch.Tensor] = None, + encoder_attention_mask_1: Optional[torch.Tensor] = None, + ): + num_layers = len(self.resnets) + num_attention_per_layer = len(self.attentions) // num_layers + + encoder_hidden_states_1 = ( + encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states + ) + encoder_attention_mask_1 = ( + encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask + ) + + for i in range(num_layers): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.resnets[i]), + hidden_states, + temb, + **ckpt_kwargs, + ) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False), + hidden_states, + forward_encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + forward_encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = self.resnets[i](hidden_states, temb) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self.attentions[i * num_attention_per_layer + idx]( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=forward_encoder_hidden_states, + encoder_attention_mask=forward_encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/diffusers/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/diffusers/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py new file mode 100755 index 0000000..b45771d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py @@ -0,0 +1,1065 @@ +# Copyright 2024 CVSSP, ByteDance and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import ( + ClapFeatureExtractor, + ClapModel, + GPT2Model, + RobertaTokenizer, + RobertaTokenizerFast, + SpeechT5HifiGan, + T5EncoderModel, + T5Tokenizer, + T5TokenizerFast, + VitsModel, + VitsTokenizer, +) + +from ...models import AutoencoderKL +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_librosa_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel + + +if is_librosa_available(): + import librosa + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import scipy + >>> import torch + >>> from diffusers import AudioLDM2Pipeline + + >>> repo_id = "cvssp/audioldm2" + >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # define the prompts + >>> prompt = "The sound of a hammer hitting a wooden surface." + >>> negative_prompt = "Low quality." + + >>> # set the seed for generator + >>> generator = torch.Generator("cuda").manual_seed(0) + + >>> # run the generation + >>> audio = pipe( + ... prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=200, + ... audio_length_in_s=10.0, + ... num_waveforms_per_prompt=3, + ... generator=generator, + ... ).audios + + >>> # save the best audio sample (index 0) as a .wav file + >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0]) + ``` + ``` + #Using AudioLDM2 for Text To Speech + >>> import scipy + >>> import torch + >>> from diffusers import AudioLDM2Pipeline + + >>> repo_id = "anhnct/audioldm2_gigaspeech" + >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # define the prompts + >>> prompt = "A female reporter is speaking" + >>> transcript = "wish you have a good day" + + >>> # set the seed for generator + >>> generator = torch.Generator("cuda").manual_seed(0) + + >>> # run the generation + >>> audio = pipe( + ... prompt, + ... transcription=transcript, + ... num_inference_steps=200, + ... audio_length_in_s=10.0, + ... num_waveforms_per_prompt=2, + ... generator=generator, + ... max_new_tokens=512, #Must set max_new_tokens equa to 512 for TTS + ... ).audios + + >>> # save the best audio sample (index 0) as a .wav file + >>> scipy.io.wavfile.write("tts.wav", rate=16000, data=audio[0]) + ``` +""" + + +def prepare_inputs_for_generation( + inputs_embeds, + attention_mask=None, + past_key_values=None, + **kwargs, +): + if past_key_values is not None: + # only last token for inputs_embeds if past is defined in kwargs + inputs_embeds = inputs_embeds[:, -1:] + + return { + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + } + + +class AudioLDM2Pipeline(DiffusionPipeline): + r""" + Pipeline for text-to-audio generation using AudioLDM2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.ClapModel`]): + First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model + [CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection), + specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The + text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to + rank generated waveforms against the text prompt by computing similarity scores. + text_encoder_2 ([`~transformers.T5EncoderModel`, `~transformers.VitsModel`]): + Second frozen text-encoder. AudioLDM2 uses the encoder of + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant. Second frozen text-encoder use + for TTS. AudioLDM2 uses the encoder of + [Vits](https://huggingface.co/docs/transformers/model_doc/vits#transformers.VitsModel). + projection_model ([`AudioLDM2ProjectionModel`]): + A trained model used to linearly project the hidden-states from the first and second text encoder models + and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are + concatenated to give the input to the language model. A Learned Position Embedding for the Vits + hidden-states + language_model ([`~transformers.GPT2Model`]): + An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected + outputs from the two text encoders. + tokenizer ([`~transformers.RobertaTokenizer`]): + Tokenizer to tokenize text for the first frozen text-encoder. + tokenizer_2 ([`~transformers.T5Tokenizer`, `~transformers.VitsTokenizer`]): + Tokenizer to tokenize text for the second frozen text-encoder. + feature_extractor ([`~transformers.ClapFeatureExtractor`]): + Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded audio latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + vocoder ([`~transformers.SpeechT5HifiGan`]): + Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ClapModel, + text_encoder_2: Union[T5EncoderModel, VitsModel], + projection_model: AudioLDM2ProjectionModel, + language_model: GPT2Model, + tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast], + tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer], + feature_extractor: ClapFeatureExtractor, + unet: AudioLDM2UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vocoder: SpeechT5HifiGan, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + projection_model=projection_model, + language_model=language_model, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + feature_extractor=feature_extractor, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + model_sequence = [ + self.text_encoder.text_model, + self.text_encoder.text_projection, + self.text_encoder_2, + self.projection_model, + self.language_model, + self.unet, + self.vae, + self.vocoder, + self.text_encoder, + ] + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def generate_language_model( + self, + inputs_embeds: torch.Tensor = None, + max_new_tokens: int = 8, + **model_kwargs, + ): + """ + + Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs. + + Parameters: + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + The sequence used as a prompt for the generation. + max_new_tokens (`int`): + Number of new tokens to generate. + model_kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward` + function of the model. + + Return: + `inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + The sequence of generated hidden-states. + """ + max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens + model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs) + for _ in range(max_new_tokens): + # prepare model inputs + model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs) + + # forward pass to get next hidden states + output = self.language_model(**model_inputs, return_dict=True) + + next_hidden_states = output.last_hidden_state + + # Update the model input + inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1) + + # Update generated hidden states, model inputs, and length for next step + model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs) + + return inputs_embeds[:, -max_new_tokens:, :] + + def encode_prompt( + self, + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + transcription=None, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + generated_prompt_embeds: Optional[torch.Tensor] = None, + negative_generated_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + max_new_tokens: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + transcription (`str` or `List[str]`): + transcription of text to speech + device (`torch.device`): + torch device + num_waveforms_per_prompt (`int`): + number of waveforms that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the audio generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + generated_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the GPT2 langauge model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input + argument. + negative_generated_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will + be computed from `prompt` input argument. + negative_attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention + mask will be computed from `negative_prompt` input argument. + max_new_tokens (`int`, *optional*, defaults to None): + The number of new tokens to generate with the GPT2 language model. + Returns: + prompt_embeds (`torch.Tensor`): + Text embeddings from the Flan T5 model. + attention_mask (`torch.LongTensor`): + Attention mask to be applied to the `prompt_embeds`. + generated_prompt_embeds (`torch.Tensor`): + Text embeddings generated from the GPT2 langauge model. + + Example: + + ```python + >>> import scipy + >>> import torch + >>> from diffusers import AudioLDM2Pipeline + + >>> repo_id = "cvssp/audioldm2" + >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # Get text embedding vectors + >>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt( + ... prompt="Techno music with a strong, upbeat tempo and high melodic riffs", + ... device="cuda", + ... do_classifier_free_guidance=True, + ... ) + + >>> # Pass text embeddings to pipeline for text-conditional audio generation + >>> audio = pipe( + ... prompt_embeds=prompt_embeds, + ... attention_mask=attention_mask, + ... generated_prompt_embeds=generated_prompt_embeds, + ... num_inference_steps=200, + ... audio_length_in_s=10.0, + ... ).audios[0] + + >>> # save generated audio sample + >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio) + ```""" + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] + is_vits_text_encoder = isinstance(self.text_encoder_2, VitsModel) + + if is_vits_text_encoder: + text_encoders = [self.text_encoder, self.text_encoder_2.text_encoder] + else: + text_encoders = [self.text_encoder, self.text_encoder_2] + + if prompt_embeds is None: + prompt_embeds_list = [] + attention_mask_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + use_prompt = isinstance( + tokenizer, (RobertaTokenizer, RobertaTokenizerFast, T5Tokenizer, T5TokenizerFast) + ) + text_inputs = tokenizer( + prompt if use_prompt else transcription, + padding="max_length" + if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer)) + else True, + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + f"The following part of your input was truncated because {text_encoder.config.model_type} can " + f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + attention_mask = attention_mask.to(device) + + if text_encoder.config.model_type == "clap": + prompt_embeds = text_encoder.get_text_features( + text_input_ids, + attention_mask=attention_mask, + ) + # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size) + prompt_embeds = prompt_embeds[:, None, :] + # make sure that we attend to this single hidden-state + attention_mask = attention_mask.new_ones((batch_size, 1)) + elif is_vits_text_encoder: + # Add end_token_id and attention mask in the end of sequence phonemes + for text_input_id, text_attention_mask in zip(text_input_ids, attention_mask): + for idx, phoneme_id in enumerate(text_input_id): + if phoneme_id == 0: + text_input_id[idx] = 182 + text_attention_mask[idx] = 1 + break + prompt_embeds = text_encoder( + text_input_ids, attention_mask=attention_mask, padding_mask=attention_mask.unsqueeze(-1) + ) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds_list.append(prompt_embeds) + attention_mask_list.append(attention_mask) + + projection_output = self.projection_model( + hidden_states=prompt_embeds_list[0], + hidden_states_1=prompt_embeds_list[1], + attention_mask=attention_mask_list[0], + attention_mask_1=attention_mask_list[1], + ) + projected_prompt_embeds = projection_output.hidden_states + projected_attention_mask = projection_output.attention_mask + + generated_prompt_embeds = self.generate_language_model( + projected_prompt_embeds, + attention_mask=projected_attention_mask, + max_new_tokens=max_new_tokens, + ) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + attention_mask = ( + attention_mask.to(device=device) + if attention_mask is not None + else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device) + ) + generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device) + + bs_embed, seq_len, hidden_size = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size) + + # duplicate attention mask for each generation per prompt + attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt) + attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len) + + bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape + # duplicate generated embeddings for each generation per prompt, using mps friendly method + generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) + generated_prompt_embeds = generated_prompt_embeds.view( + bs_embed * num_waveforms_per_prompt, seq_len, hidden_size + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + negative_attention_mask_list = [] + max_length = prompt_embeds.shape[1] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=tokenizer.model_max_length + if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer)) + else max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + if text_encoder.config.model_type == "clap": + negative_prompt_embeds = text_encoder.get_text_features( + uncond_input_ids, + attention_mask=negative_attention_mask, + ) + # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size) + negative_prompt_embeds = negative_prompt_embeds[:, None, :] + # make sure that we attend to this single hidden-state + negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1)) + elif is_vits_text_encoder: + negative_prompt_embeds = torch.zeros( + batch_size, + tokenizer.model_max_length, + text_encoder.config.hidden_size, + ).to(dtype=self.text_encoder_2.dtype, device=device) + negative_attention_mask = torch.zeros(batch_size, tokenizer.model_max_length).to( + dtype=self.text_encoder_2.dtype, device=device + ) + else: + negative_prompt_embeds = text_encoder( + uncond_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + negative_attention_mask_list.append(negative_attention_mask) + + projection_output = self.projection_model( + hidden_states=negative_prompt_embeds_list[0], + hidden_states_1=negative_prompt_embeds_list[1], + attention_mask=negative_attention_mask_list[0], + attention_mask_1=negative_attention_mask_list[1], + ) + negative_projected_prompt_embeds = projection_output.hidden_states + negative_projected_attention_mask = projection_output.attention_mask + + negative_generated_prompt_embeds = self.generate_language_model( + negative_projected_prompt_embeds, + attention_mask=negative_projected_attention_mask, + max_new_tokens=max_new_tokens, + ) + + if do_classifier_free_guidance: + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_attention_mask = ( + negative_attention_mask.to(device=device) + if negative_attention_mask is not None + else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device) + ) + negative_generated_prompt_embeds = negative_generated_prompt_embeds.to( + dtype=self.language_model.dtype, device=device + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1) + + # duplicate unconditional attention mask for each generation per prompt + negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt) + negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len) + + # duplicate unconditional generated embeddings for each generation per prompt + seq_len = negative_generated_prompt_embeds.shape[1] + negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) + negative_generated_prompt_embeds = negative_generated_prompt_embeds.view( + batch_size * num_waveforms_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([negative_attention_mask, attention_mask]) + generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds]) + + return prompt_embeds, attention_mask, generated_prompt_embeds + + # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform + def mel_spectrogram_to_waveform(self, mel_spectrogram): + if mel_spectrogram.dim() == 4: + mel_spectrogram = mel_spectrogram.squeeze(1) + + waveform = self.vocoder(mel_spectrogram) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + waveform = waveform.cpu().float() + return waveform + + def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype): + if not is_librosa_available(): + logger.info( + "Automatic scoring of the generated audio waveforms against the input prompt text requires the " + "`librosa` package to resample the generated waveforms. Returning the audios in the order they were " + "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`." + ) + return audio + inputs = self.tokenizer(text, return_tensors="pt", padding=True) + resampled_audio = librosa.resample( + audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate + ) + inputs["input_features"] = self.feature_extractor( + list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate + ).input_features.type(dtype) + inputs = inputs.to(device) + + # compute the audio-text similarity score using the CLAP model + logits_per_text = self.text_encoder(**inputs).logits_per_text + # sort by the highest matching generations per prompt + indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt] + audio = torch.index_select(audio, 0, indices.reshape(-1).cpu()) + return audio + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + transcription=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + generated_prompt_embeds=None, + negative_generated_prompt_embeds=None, + attention_mask=None, + negative_attention_mask=None, + ): + min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor + if audio_length_in_s < min_audio_length_in_s: + raise ValueError( + f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but " + f"is {audio_length_in_s}." + ) + + if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0: + raise ValueError( + f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the " + f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of " + f"{self.vae_scale_factor}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None): + raise ValueError( + "Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave " + "`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None: + raise ValueError( + "Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that" + "both arguments are specified" + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]: + raise ValueError( + "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" + f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}" + ) + + if transcription is None: + if self.text_encoder_2.config.model_type == "vits": + raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription") + elif transcription is not None and ( + not isinstance(transcription, str) and not isinstance(transcription, list) + ): + raise ValueError(f"`transcription` has to be of type `str` or `list` but is {type(transcription)}") + + if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None: + if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape: + raise ValueError( + "`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when " + f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != " + f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}." + ) + if ( + negative_attention_mask is not None + and negative_attention_mask.shape != negative_prompt_embeds.shape[:2] + ): + raise ValueError( + "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" + f"`attention_mask: {negative_attention_mask.shape} != `prompt_embeds` {negative_prompt_embeds.shape}" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim + def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(self.vocoder.config.model_in_dim) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + transcription: Union[str, List[str]] = None, + audio_length_in_s: Optional[float] = None, + num_inference_steps: int = 200, + guidance_scale: float = 3.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_waveforms_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + generated_prompt_embeds: Optional[torch.Tensor] = None, + negative_generated_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + max_new_tokens: Optional[int] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + output_type: Optional[str] = "np", + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + transcription (`str` or `List[str]`, *optional*):\ + The transcript for text to speech. + audio_length_in_s (`int`, *optional*, defaults to 10.24): + The length of the generated audio sample in seconds. + num_inference_steps (`int`, *optional*, defaults to 200): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 3.5): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic + scoring is performed between the generated outputs and the text prompt. This scoring ranks the + generated waveforms based on their cosine similarity with the text input in the joint text-audio + embedding space. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + generated_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the GPT2 langauge model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input + argument. + negative_generated_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will + be computed from `prompt` input argument. + negative_attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention + mask will be computed from `negative_prompt` input argument. + max_new_tokens (`int`, *optional*, defaults to None): + Number of new tokens to generate with the GPT2 language model. If not provided, number of tokens will + be taken from the config of the model. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion + model (LDM) output. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to spectrogram height + vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor + + height = int(audio_length_in_s / vocoder_upsample_factor) + + original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) + if height % self.vae_scale_factor != 0: + height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor + logger.info( + f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} " + f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the " + f"denoising process." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + transcription, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + generated_prompt_embeds, + negative_generated_prompt_embeds, + attention_mask, + negative_attention_mask, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt( + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + transcription, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + generated_prompt_embeds=generated_prompt_embeds, + negative_generated_prompt_embeds=negative_generated_prompt_embeds, + attention_mask=attention_mask, + negative_attention_mask=negative_attention_mask, + max_new_tokens=max_new_tokens, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_latents, + height, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=generated_prompt_embeds, + encoder_hidden_states_1=prompt_embeds, + encoder_attention_mask_1=attention_mask, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + self.maybe_free_model_hooks() + + # 8. Post-processing + if not output_type == "latent": + latents = 1 / self.vae.config.scaling_factor * latents + mel_spectrogram = self.vae.decode(latents).sample + else: + return AudioPipelineOutput(audios=latents) + + audio = self.mel_spectrogram_to_waveform(mel_spectrogram) + + audio = audio[:, :original_waveform_length] + + # 9. Automatic scoring + if num_waveforms_per_prompt > 1 and prompt is not None: + audio = self.score_waveforms( + text=prompt, + audio=audio, + num_waveforms_per_prompt=num_waveforms_per_prompt, + device=device, + dtype=prompt_embeds.dtype, + ) + + if output_type == "np": + audio = audio.numpy() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/diffusers/src/diffusers/pipelines/aura_flow/__init__.py b/diffusers/src/diffusers/pipelines/aura_flow/__init__.py new file mode 100755 index 0000000..e1917ba --- /dev/null +++ b/diffusers/src/diffusers/pipelines/aura_flow/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_aura_flow import AuraFlowPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/diffusers/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py new file mode 100755 index 0000000..6a86b5c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -0,0 +1,591 @@ +# Copyright 2024 AuraFlow Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import List, Optional, Tuple, Union + +import torch +from transformers import T5Tokenizer, UMT5EncoderModel + +from ...image_processor import VaeImageProcessor +from ...models import AuraFlowTransformer2DModel, AutoencoderKL +from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AuraFlowPipeline + + >>> pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt).images[0] + >>> image.save("aura_flow.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AuraFlowPipeline(DiffusionPipeline): + r""" + Args: + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. AuraFlow uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + transformer ([`AuraFlowTransformer2DModel`]): + Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKL, + transformer: AuraFlowTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + """ + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = max_sequence_length + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + text_input_ids = text_inputs["input_ids"] + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + text_inputs = {k: v.to(device) for k, v in text_inputs.items()} + prompt_embeds = self.text_encoder(**text_inputs)[0] + prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) + prompt_embeds = prompt_embeds * prompt_attention_mask + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.reshape(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + uncond_input = {k: v.to(device) for k, v in uncond_input.items()} + negative_prompt_embeds = self.text_encoder(**uncond_input)[0] + negative_prompt_attention_mask = ( + uncond_input["attention_mask"].unsqueeze(-1).expand(negative_prompt_embeds.shape) + ) + negative_prompt_embeds = negative_prompt_embeds * negative_prompt_attention_mask + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.reshape(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = 1024, + width: Optional[int] = 1024, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images. + """ + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Determine batch size. + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + + # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0]) + timestep = timestep.to(latents.device, dtype=latents.dtype) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/auto_pipeline.py b/diffusers/src/diffusers/pipelines/auto_pipeline.py new file mode 100755 index 0000000..f6186da --- /dev/null +++ b/diffusers/src/diffusers/pipelines/auto_pipeline.py @@ -0,0 +1,1102 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + +from huggingface_hub.utils import validate_hf_hub_args + +from ..configuration_utils import ConfigMixin +from ..utils import is_sentencepiece_available +from .aura_flow import AuraFlowPipeline +from .controlnet import ( + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPipeline, + StableDiffusionXLControlNetImg2ImgPipeline, + StableDiffusionXLControlNetInpaintPipeline, + StableDiffusionXLControlNetPipeline, +) +from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline +from .flux import ( + FluxControlNetImg2ImgPipeline, + FluxControlNetInpaintPipeline, + FluxControlNetPipeline, + FluxImg2ImgPipeline, + FluxInpaintPipeline, + FluxPipeline, +) +from .hunyuandit import HunyuanDiTPipeline +from .kandinsky import ( + KandinskyCombinedPipeline, + KandinskyImg2ImgCombinedPipeline, + KandinskyImg2ImgPipeline, + KandinskyInpaintCombinedPipeline, + KandinskyInpaintPipeline, + KandinskyPipeline, +) +from .kandinsky2_2 import ( + KandinskyV22CombinedPipeline, + KandinskyV22Img2ImgCombinedPipeline, + KandinskyV22Img2ImgPipeline, + KandinskyV22InpaintCombinedPipeline, + KandinskyV22InpaintPipeline, + KandinskyV22Pipeline, +) +from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline +from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline +from .lumina import LuminaText2ImgPipeline +from .pag import ( + HunyuanDiTPAGPipeline, + PixArtSigmaPAGPipeline, + StableDiffusion3PAGPipeline, + StableDiffusionControlNetPAGPipeline, + StableDiffusionPAGPipeline, + StableDiffusionXLControlNetPAGImg2ImgPipeline, + StableDiffusionXLControlNetPAGPipeline, + StableDiffusionXLPAGImg2ImgPipeline, + StableDiffusionXLPAGInpaintPipeline, + StableDiffusionXLPAGPipeline, +) +from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline +from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline +from .stable_diffusion import ( + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, +) +from .stable_diffusion_3 import ( + StableDiffusion3Img2ImgPipeline, + StableDiffusion3InpaintPipeline, + StableDiffusion3Pipeline, +) +from .stable_diffusion_xl import ( + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, +) +from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline + + +AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", StableDiffusionPipeline), + ("stable-diffusion-xl", StableDiffusionXLPipeline), + ("stable-diffusion-3", StableDiffusion3Pipeline), + ("stable-diffusion-3-pag", StableDiffusion3PAGPipeline), + ("if", IFPipeline), + ("hunyuan", HunyuanDiTPipeline), + ("hunyuan-pag", HunyuanDiTPAGPipeline), + ("kandinsky", KandinskyCombinedPipeline), + ("kandinsky22", KandinskyV22CombinedPipeline), + ("kandinsky3", Kandinsky3Pipeline), + ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline), + ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline), + ("wuerstchen", WuerstchenCombinedPipeline), + ("cascade", StableCascadeCombinedPipeline), + ("lcm", LatentConsistencyModelPipeline), + ("pixart-alpha", PixArtAlphaPipeline), + ("pixart-sigma", PixArtSigmaPipeline), + ("stable-diffusion-pag", StableDiffusionPAGPipeline), + ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline), + ("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline), + ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline), + ("pixart-sigma-pag", PixArtSigmaPAGPipeline), + ("auraflow", AuraFlowPipeline), + ("flux", FluxPipeline), + ("flux-controlnet", FluxControlNetPipeline), + ("lumina", LuminaText2ImgPipeline), + ] +) + +AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", StableDiffusionImg2ImgPipeline), + ("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline), + ("stable-diffusion-3", StableDiffusion3Img2ImgPipeline), + ("if", IFImg2ImgPipeline), + ("kandinsky", KandinskyImg2ImgCombinedPipeline), + ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline), + ("kandinsky3", Kandinsky3Img2ImgPipeline), + ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline), + ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline), + ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), + ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), + ("lcm", LatentConsistencyModelImg2ImgPipeline), + ("flux", FluxImg2ImgPipeline), + ("flux-controlnet", FluxControlNetImg2ImgPipeline), + ] +) + +AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", StableDiffusionInpaintPipeline), + ("stable-diffusion-xl", StableDiffusionXLInpaintPipeline), + ("stable-diffusion-3", StableDiffusion3InpaintPipeline), + ("if", IFInpaintingPipeline), + ("kandinsky", KandinskyInpaintCombinedPipeline), + ("kandinsky22", KandinskyV22InpaintCombinedPipeline), + ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), + ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), + ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), + ("flux", FluxInpaintPipeline), + ("flux-controlnet", FluxControlNetInpaintPipeline), + ] +) + +_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict( + [ + ("kandinsky", KandinskyPipeline), + ("kandinsky22", KandinskyV22Pipeline), + ("wuerstchen", WuerstchenDecoderPipeline), + ("cascade", StableCascadeDecoderPipeline), + ] +) +_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict( + [ + ("kandinsky", KandinskyImg2ImgPipeline), + ("kandinsky22", KandinskyV22Img2ImgPipeline), + ] +) +_AUTO_INPAINT_DECODER_PIPELINES_MAPPING = OrderedDict( + [ + ("kandinsky", KandinskyInpaintPipeline), + ("kandinsky22", KandinskyV22InpaintPipeline), + ] +) + +if is_sentencepiece_available(): + from .kolors import KolorsImg2ImgPipeline, KolorsPipeline + from .pag import KolorsPAGPipeline + + AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline + AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors-pag"] = KolorsPAGPipeline + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsImg2ImgPipeline + +SUPPORTED_TASKS_MAPPINGS = [ + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, + AUTO_INPAINT_PIPELINES_MAPPING, + _AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING, + _AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING, + _AUTO_INPAINT_DECODER_PIPELINES_MAPPING, +] + + +def _get_connected_pipeline(pipeline_cls): + # for now connected pipelines can only be loaded from decoder pipelines, such as kandinsky-community/kandinsky-2-2-decoder + if pipeline_cls in _AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING.values(): + return _get_task_class( + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False + ) + if pipeline_cls in _AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING.values(): + return _get_task_class( + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False + ) + if pipeline_cls in _AUTO_INPAINT_DECODER_PIPELINES_MAPPING.values(): + return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False) + + +def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True): + def get_model(pipeline_class_name): + for task_mapping in SUPPORTED_TASKS_MAPPINGS: + for model_name, pipeline in task_mapping.items(): + if pipeline.__name__ == pipeline_class_name: + return model_name + + model_name = get_model(pipeline_class_name) + + if model_name is not None: + task_class = mapping.get(model_name, None) + if task_class is not None: + return task_class + + if throw_error_if_not_exist: + raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}") + + +class AutoPipelineForText2Image(ConfigMixin): + r""" + + [`AutoPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The + specific underlying pipeline class is automatically selected from either the + [`~AutoPipelineForText2Image.from_pretrained`] or [`~AutoPipelineForText2Image.from_pipe`] methods. + + This class cannot be instantiated using `__init__()` (throws an error). + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + + """ + + config_name = "model_index.json" + + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"`{self.__class__.__name__}.from_pipe(pipeline)` methods." + ) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_or_path, **kwargs): + r""" + Instantiates a text-to-image Pytorch diffusion pipeline from pretrained pipeline weight. + + The from_pretrained() method takes care of returning the correct pipeline class instance by: + 1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its + config object + 2. Find the text-to-image pipeline linked to the pipeline class using pattern matching on pipeline class + name. + + If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetPipeline`] object. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + If you get the error message below, you need to finetune the weights for your downstream task: + + ``` + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + + Parameters: + pretrained_model_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights + saved using + [`~DiffusionPipeline.save_pretrained`]. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a + custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn’t need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if device_map contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. + variant (`str`, *optional*): + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + + + + To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with + `huggingface-cli login`. + + + + Examples: + + ```py + >>> from diffusers import AutoPipelineForText2Image + + >>> pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> image = pipeline(prompt).images[0] + ``` + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + + load_config_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "token": token, + "local_files_only": local_files_only, + "revision": revision, + } + + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) + orig_class_name = config["_class_name"] + + if "controlnet" in kwargs: + orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: + orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") + + text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name) + + kwargs = {**load_config_kwargs, **kwargs} + return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs) + + @classmethod + def from_pipe(cls, pipeline, **kwargs): + r""" + Instantiates a text-to-image Pytorch diffusion pipeline from another instantiated diffusion pipeline class. + + The from_pipe() method takes care of returning the correct pipeline class instance by finding the text-to-image + pipeline linked to the pipeline class using pattern matching on pipeline class name. + + All the modules the pipeline contains will be used to initialize the new pipeline without reallocating + additional memory. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pipeline (`DiffusionPipeline`): + an instantiated `DiffusionPipeline` object + + ```py + >>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image + + >>> pipe_i2i = AutoPipelineForImage2Image.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False + ... ) + + >>> pipe_t2i = AutoPipelineForText2Image.from_pipe(pipe_i2i) + >>> image = pipe_t2i(prompt).images[0] + ``` + """ + + original_config = dict(pipeline.config) + original_cls_name = pipeline.__class__.__name__ + + # derive the pipeline class to instantiate + text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, original_cls_name) + + if "controlnet" in kwargs: + if kwargs["controlnet"] is not None: + to_replace = "PAGPipeline" if "PAG" in text_2_image_cls.__name__ else "Pipeline" + text_2_image_cls = _get_task_class( + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, + text_2_image_cls.__name__.replace("ControlNet", "").replace(to_replace, "ControlNet" + to_replace), + ) + else: + text_2_image_cls = _get_task_class( + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, + text_2_image_cls.__name__.replace("ControlNet", ""), + ) + + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: + text_2_image_cls = _get_task_class( + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, + text_2_image_cls.__name__.replace("PAG", "").replace("Pipeline", "PAGPipeline"), + ) + else: + text_2_image_cls = _get_task_class( + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, + text_2_image_cls.__name__.replace("PAG", ""), + ) + + # define expected module and optional kwargs given the pipeline signature + expected_modules, optional_kwargs = text_2_image_cls._get_signature_keys(text_2_image_cls) + + pretrained_model_name_or_path = original_config.pop("_name_or_path", None) + + # allow users pass modules in `kwargs` to override the original pipeline's components + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + original_class_obj = { + k: pipeline.components[k] + for k, v in pipeline.components.items() + if k in expected_modules and k not in passed_class_obj + } + + # allow users pass optional kwargs to override the original pipelines config attribute + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + original_pipe_kwargs = { + k: original_config[k] + for k, v in original_config.items() + if k in optional_kwargs and k not in passed_pipe_kwargs + } + + # config that were not expected by original pipeline is stored as private attribute + # we will pass them as optional arguments if they can be accepted by the pipeline + additional_pipe_kwargs = [ + k[1:] + for k in original_config.keys() + if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs + ] + for k in additional_pipe_kwargs: + original_pipe_kwargs[k] = original_config.pop(f"_{k}") + + text_2_image_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs} + + # store unused config as private attribute + unused_original_config = { + f"{'' if k.startswith('_') else '_'}{k}": original_config[k] + for k, v in original_config.items() + if k not in text_2_image_kwargs + } + + missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(text_2_image_kwargs.keys()) + + if len(missing_modules) > 0: + raise ValueError( + f"Pipeline {text_2_image_cls} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed" + ) + + model = text_2_image_cls(**text_2_image_kwargs) + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + model.register_to_config(**unused_original_config) + + return model + + +class AutoPipelineForImage2Image(ConfigMixin): + r""" + + [`AutoPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The + specific underlying pipeline class is automatically selected from either the + [`~AutoPipelineForImage2Image.from_pretrained`] or [`~AutoPipelineForImage2Image.from_pipe`] methods. + + This class cannot be instantiated using `__init__()` (throws an error). + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + + """ + + config_name = "model_index.json" + + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"`{self.__class__.__name__}.from_pipe(pipeline)` methods." + ) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_or_path, **kwargs): + r""" + Instantiates a image-to-image Pytorch diffusion pipeline from pretrained pipeline weight. + + The from_pretrained() method takes care of returning the correct pipeline class instance by: + 1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its + config object + 2. Find the image-to-image pipeline linked to the pipeline class using pattern matching on pipeline class + name. + + If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetImg2ImgPipeline`] + object. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + If you get the error message below, you need to finetune the weights for your downstream task: + + ``` + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + + Parameters: + pretrained_model_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights + saved using + [`~DiffusionPipeline.save_pretrained`]. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a + custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn’t need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if device_map contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. + variant (`str`, *optional*): + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + + + + To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with + `huggingface-cli login`. + + + + Examples: + + ```py + >>> from diffusers import AutoPipelineForImage2Image + + >>> pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> image = pipeline(prompt, image).images[0] + ``` + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + + load_config_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "token": token, + "local_files_only": local_files_only, + "revision": revision, + } + + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) + orig_class_name = config["_class_name"] + + # the `orig_class_name` can be: + # `- *Pipeline` (for regular text-to-image checkpoint) + # `- *Img2ImgPipeline` (for refiner checkpoint) + to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" + + if "controlnet" in kwargs: + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: + orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) + + image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) + + kwargs = {**load_config_kwargs, **kwargs} + return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs) + + @classmethod + def from_pipe(cls, pipeline, **kwargs): + r""" + Instantiates a image-to-image Pytorch diffusion pipeline from another instantiated diffusion pipeline class. + + The from_pipe() method takes care of returning the correct pipeline class instance by finding the + image-to-image pipeline linked to the pipeline class using pattern matching on pipeline class name. + + All the modules the pipeline contains will be used to initialize the new pipeline without reallocating + additional memory. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pipeline (`DiffusionPipeline`): + an instantiated `DiffusionPipeline` object + + Examples: + + ```py + >>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image + + >>> pipe_t2i = AutoPipelineForText2Image.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False + ... ) + + >>> pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i) + >>> image = pipe_i2i(prompt, image).images[0] + ``` + """ + + original_config = dict(pipeline.config) + original_cls_name = pipeline.__class__.__name__ + + # derive the pipeline class to instantiate + image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, original_cls_name) + + if "controlnet" in kwargs: + if kwargs["controlnet"] is not None: + to_replace = "Img2ImgPipeline" + if "PAG" in image_2_image_cls.__name__: + to_replace = "PAG" + to_replace + image_2_image_cls = _get_task_class( + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, + image_2_image_cls.__name__.replace("ControlNet", "").replace( + to_replace, "ControlNet" + to_replace + ), + ) + else: + image_2_image_cls = _get_task_class( + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, + image_2_image_cls.__name__.replace("ControlNet", ""), + ) + + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: + image_2_image_cls = _get_task_class( + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, + image_2_image_cls.__name__.replace("PAG", "").replace("Img2ImgPipeline", "PAGImg2ImgPipeline"), + ) + else: + image_2_image_cls = _get_task_class( + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, + image_2_image_cls.__name__.replace("PAG", ""), + ) + + # define expected module and optional kwargs given the pipeline signature + expected_modules, optional_kwargs = image_2_image_cls._get_signature_keys(image_2_image_cls) + + pretrained_model_name_or_path = original_config.pop("_name_or_path", None) + + # allow users pass modules in `kwargs` to override the original pipeline's components + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + original_class_obj = { + k: pipeline.components[k] + for k, v in pipeline.components.items() + if k in expected_modules and k not in passed_class_obj + } + + # allow users pass optional kwargs to override the original pipelines config attribute + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + original_pipe_kwargs = { + k: original_config[k] + for k, v in original_config.items() + if k in optional_kwargs and k not in passed_pipe_kwargs + } + + # config attribute that were not expected by original pipeline is stored as its private attribute + # we will pass them as optional arguments if they can be accepted by the pipeline + additional_pipe_kwargs = [ + k[1:] + for k in original_config.keys() + if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs + ] + for k in additional_pipe_kwargs: + original_pipe_kwargs[k] = original_config.pop(f"_{k}") + + image_2_image_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs} + + # store unused config as private attribute + unused_original_config = { + f"{'' if k.startswith('_') else '_'}{k}": original_config[k] + for k, v in original_config.items() + if k not in image_2_image_kwargs + } + + missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(image_2_image_kwargs.keys()) + + if len(missing_modules) > 0: + raise ValueError( + f"Pipeline {image_2_image_cls} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed" + ) + + model = image_2_image_cls(**image_2_image_kwargs) + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + model.register_to_config(**unused_original_config) + + return model + + +class AutoPipelineForInpainting(ConfigMixin): + r""" + + [`AutoPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The + specific underlying pipeline class is automatically selected from either the + [`~AutoPipelineForInpainting.from_pretrained`] or [`~AutoPipelineForInpainting.from_pipe`] methods. + + This class cannot be instantiated using `__init__()` (throws an error). + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + + """ + + config_name = "model_index.json" + + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"`{self.__class__.__name__}.from_pipe(pipeline)` methods." + ) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_or_path, **kwargs): + r""" + Instantiates a inpainting Pytorch diffusion pipeline from pretrained pipeline weight. + + The from_pretrained() method takes care of returning the correct pipeline class instance by: + 1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its + config object + 2. Find the inpainting pipeline linked to the pipeline class using pattern matching on pipeline class name. + + If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetInpaintPipeline`] + object. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + If you get the error message below, you need to finetune the weights for your downstream task: + + ``` + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + + Parameters: + pretrained_model_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights + saved using + [`~DiffusionPipeline.save_pretrained`]. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a + custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn’t need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if device_map contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. + variant (`str`, *optional*): + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + + + + To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with + `huggingface-cli login`. + + + + Examples: + + ```py + >>> from diffusers import AutoPipelineForInpainting + + >>> pipeline = AutoPipelineForInpainting.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0] + ``` + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + + load_config_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "token": token, + "local_files_only": local_files_only, + "revision": revision, + } + + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) + orig_class_name = config["_class_name"] + + # The `orig_class_name`` can be: + # `- *InpaintPipeline` (for inpaint-specific checkpoint) + # - or *Pipeline (for regular text-to-image checkpoint) + to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" + + if "controlnet" in kwargs: + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: + orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) + inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) + + kwargs = {**load_config_kwargs, **kwargs} + return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs) + + @classmethod + def from_pipe(cls, pipeline, **kwargs): + r""" + Instantiates a inpainting Pytorch diffusion pipeline from another instantiated diffusion pipeline class. + + The from_pipe() method takes care of returning the correct pipeline class instance by finding the inpainting + pipeline linked to the pipeline class using pattern matching on pipeline class name. + + All the modules the pipeline class contain will be used to initialize the new pipeline without reallocating + additional memory. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pipeline (`DiffusionPipeline`): + an instantiated `DiffusionPipeline` object + + Examples: + + ```py + >>> from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting + + >>> pipe_t2i = AutoPipelineForText2Image.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", requires_safety_checker=False + ... ) + + >>> pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_t2i) + >>> image = pipe_inpaint(prompt, image=init_image, mask_image=mask_image).images[0] + ``` + """ + original_config = dict(pipeline.config) + original_cls_name = pipeline.__class__.__name__ + + # derive the pipeline class to instantiate + inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, original_cls_name) + + if "controlnet" in kwargs: + if kwargs["controlnet"] is not None: + inpainting_cls = _get_task_class( + AUTO_INPAINT_PIPELINES_MAPPING, + inpainting_cls.__name__.replace("ControlNet", "").replace( + "InpaintPipeline", "ControlNetInpaintPipeline" + ), + ) + else: + inpainting_cls = _get_task_class( + AUTO_INPAINT_PIPELINES_MAPPING, + inpainting_cls.__name__.replace("ControlNetInpaintPipeline", "InpaintPipeline"), + ) + + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: + inpainting_cls = _get_task_class( + AUTO_INPAINT_PIPELINES_MAPPING, + inpainting_cls.__name__.replace("PAG", "").replace("InpaintPipeline", "PAGInpaintPipeline"), + ) + else: + inpainting_cls = _get_task_class( + AUTO_INPAINT_PIPELINES_MAPPING, + inpainting_cls.__name__.replace("PAGInpaintPipeline", "InpaintPipeline"), + ) + + # define expected module and optional kwargs given the pipeline signature + expected_modules, optional_kwargs = inpainting_cls._get_signature_keys(inpainting_cls) + + pretrained_model_name_or_path = original_config.pop("_name_or_path", None) + + # allow users pass modules in `kwargs` to override the original pipeline's components + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + original_class_obj = { + k: pipeline.components[k] + for k, v in pipeline.components.items() + if k in expected_modules and k not in passed_class_obj + } + + # allow users pass optional kwargs to override the original pipelines config attribute + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + original_pipe_kwargs = { + k: original_config[k] + for k, v in original_config.items() + if k in optional_kwargs and k not in passed_pipe_kwargs + } + + # config that were not expected by original pipeline is stored as private attribute + # we will pass them as optional arguments if they can be accepted by the pipeline + additional_pipe_kwargs = [ + k[1:] + for k in original_config.keys() + if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs + ] + for k in additional_pipe_kwargs: + original_pipe_kwargs[k] = original_config.pop(f"_{k}") + + inpainting_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs} + + # store unused config as private attribute + unused_original_config = { + f"{'' if k.startswith('_') else '_'}{k}": original_config[k] + for k, v in original_config.items() + if k not in inpainting_kwargs + } + + missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(inpainting_kwargs.keys()) + + if len(missing_modules) > 0: + raise ValueError( + f"Pipeline {inpainting_cls} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed" + ) + + model = inpainting_cls(**inpainting_kwargs) + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + model.register_to_config(**unused_original_config) + + return model diff --git a/diffusers/src/diffusers/pipelines/blip_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/blip_diffusion/__init__.py new file mode 100755 index 0000000..af6c879 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/blip_diffusion/__init__.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +from PIL import Image + +from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline +else: + from .blip_image_processing import BlipImageProcessor + from .modeling_blip2 import Blip2QFormerModel + from .modeling_ctx_clip import ContextCLIPTextModel + from .pipeline_blip_diffusion import BlipDiffusionPipeline diff --git a/diffusers/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py b/diffusers/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py new file mode 100755 index 0000000..d92a076 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py @@ -0,0 +1,318 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for BLIP.""" + +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format +from transformers.image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, +) +from transformers.utils import TensorType, is_vision_available, logging + +from diffusers.utils import numpy_to_pil + + +if is_vision_available(): + import PIL.Image + + +logger = logging.get_logger(__name__) + + +# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop +# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor +class BlipImageProcessor(BaseImageProcessor): + r""" + Constructs a BLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + do_center_crop: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + self.do_center_crop = do_center_crop + + # Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + do_center_crop: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + do_convert_rgb: bool = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Controls the size of the image after `resize`. The shortest edge of the image is resized to + `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image + is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest + edge equal to `int(size["shortest_edge"] * (1333 / 800))`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to normalize the image by if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to normalize the image by if `do_normalize` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if do_resize and size is None or resample is None: + raise ValueError("Size and resample must be specified if do_resize is True.") + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError("Image mean and std must be specified if do_normalize is True.") + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + if do_center_crop: + images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + return encoded_outputs + + # Follows diffusers.VaeImageProcessor.postprocess + def postprocess(self, sample: torch.Tensor, output_type: str = "pil"): + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']" + ) + + # Equivalent to diffusers.VaeImageProcessor.denormalize + sample = (sample / 2 + 0.5).clamp(0, 1) + if output_type == "pt": + return sample + + # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "np": + return sample + # Output_type must be 'pil' + sample = numpy_to_pil(sample) + return sample diff --git a/diffusers/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/diffusers/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py new file mode 100755 index 0000000..1be4761 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py @@ -0,0 +1,642 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers import BertTokenizer +from transformers.activations import QuickGELUActivation as QuickGELU +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig +from transformers.models.blip_2.modeling_blip_2 import ( + Blip2Encoder, + Blip2PreTrainedModel, + Blip2QFormerAttention, + Blip2QFormerIntermediate, + Blip2QFormerOutput, +) +from transformers.pytorch_utils import apply_chunking_to_forward +from transformers.utils import ( + logging, + replace_return_docstrings, +) + + +logger = logging.get_logger(__name__) + + +# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py. +# But it doesn't support getting multimodal embeddings. So, this module can be +# replaced with a future `transformers` version supports that. +class Blip2TextEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + batch_size = embeddings.shape[0] + # repeat the query embeddings for batch size + query_embeds = query_embeds.repeat(batch_size, 1, 1) + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + embeddings = embeddings.to(query_embeds.dtype) + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 +class Blip2VisionEmbeddings(nn.Module): + def __init__(self, config: Blip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + return embeddings + + +# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings +class Blip2QFormerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions, query_length) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if layer_module.has_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# The layers making up the Qformer encoder +class Blip2QFormerLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Blip2QFormerAttention(config) + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = Blip2QFormerIntermediate(config) + self.intermediate_query = Blip2QFormerIntermediate(config) + self.output_query = Blip2QFormerOutput(config) + self.output = Blip2QFormerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + if encoder_hidden_states is None: + raise ValueError("encoder_hidden_states must be given for cross-attention layers") + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder +class ProjLayer(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12): + super().__init__() + + # Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm + self.dense1 = nn.Linear(in_dim, hidden_dim) + self.act_fn = QuickGELU() + self.dense2 = nn.Linear(hidden_dim, out_dim) + self.dropout = nn.Dropout(drop_p) + + self.LayerNorm = nn.LayerNorm(out_dim, eps=eps) + + def forward(self, x): + x_in = x + + x = self.LayerNorm(x) + x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in + + return x + + +# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2 +class Blip2VisionModel(Blip2PreTrainedModel): + main_input_name = "pixel_values" + config_class = Blip2VisionConfig + + def __init__(self, config: Blip2VisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + self.embeddings = Blip2VisionEmbeddings(config) + self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = Blip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.post_init() + + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layernorm(hidden_states) + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.embeddings + + +# Qformer model, used to get multimodal embeddings from the text and image inputs +class Blip2QFormerModel(Blip2PreTrainedModel): + """ + Querying Transformer (Q-Former), used in BLIP-2. + """ + + def __init__(self, config: Blip2Config): + super().__init__(config) + self.config = config + self.embeddings = Blip2TextEmbeddings(config.qformer_config) + self.visual_encoder = Blip2VisionModel(config.vision_config) + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + if not hasattr(config, "tokenizer") or config.tokenizer is None: + self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right") + else: + self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right") + self.tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + self.proj_layer = ProjLayer( + in_dim=config.qformer_config.hidden_size, + out_dim=config.qformer_config.hidden_size, + hidden_dim=config.qformer_config.hidden_size * 4, + drop_p=0.1, + eps=1e-12, + ) + + self.encoder = Blip2QFormerEncoder(config.qformer_config) + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int], + device: torch.device, + has_query: bool = False, + ) -> torch.Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + device (`torch.device`): + The device of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + text_input=None, + image_input=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of: + shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and + value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are + used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key + value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape + `(batch_size, sequence_length)`. + use_cache (`bool`, `optional`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + + text = self.tokenizer(text_input, return_tensors="pt", padding=True) + text = text.to(self.device) + input_ids = text.input_ids + batch_size = input_ids.shape[0] + query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device) + attention_mask = torch.cat([query_atts, text.attention_mask], dim=1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 + ) + + query_length = self.query_tokens.shape[1] + + embedding_output = self.embeddings( + input_ids=input_ids, + query_embeds=self.query_tokens, + past_key_values_length=past_key_values_length, + ) + + # embedding_output = self.layernorm(query_embeds) + # embedding_output = self.dropout(embedding_output) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state + # image_embeds_frozen = torch.ones_like(image_embeds_frozen) + encoder_hidden_states = image_embeds_frozen + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, list): + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if isinstance(encoder_attention_mask, list): + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + return self.proj_layer(sequence_output[:, :query_length, :]) + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) diff --git a/diffusers/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py b/diffusers/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py new file mode 100755 index 0000000..d29dddf --- /dev/null +++ b/diffusers/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py @@ -0,0 +1,223 @@ +# Copyright 2024 Salesforce.com, inc. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from transformers import CLIPPreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.clip.configuration_clip import CLIPTextConfig +from transformers.models.clip.modeling_clip import CLIPEncoder + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip +# Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer +# They pass through the clip model, along with the text embeddings, and interact with them using self attention +class ContextCLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = ContextCLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + ctx_embeddings: torch.Tensor = None, + ctx_begin_pos: list = None, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + return self.text_model( + ctx_embeddings=ctx_embeddings, + ctx_begin_pos=ctx_begin_pos, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class ContextCLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = ContextCLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + ctx_embeddings: torch.Tensor, + ctx_begin_pos: list, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify either input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + ctx_embeddings=ctx_embeddings, + ctx_begin_pos=ctx_begin_pos, + ) + + bsz, seq_len = input_shape + if ctx_embeddings is not None: + seq_len += ctx_embeddings.size(1) + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( + hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=input_ids.device), + input_ids.to(torch.int).argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _build_causal_attention_mask(self, bsz, seq_len, dtype): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask + + +class ContextCLIPTextEmbeddings(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward( + self, + ctx_embeddings: torch.Tensor, + ctx_begin_pos: list, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if ctx_embeddings is None: + ctx_len = 0 + else: + ctx_len = ctx_embeddings.shape[1] + + seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + # for each input embeddings, add the ctx embeddings at the correct position + input_embeds_ctx = [] + bsz = inputs_embeds.shape[0] + + if ctx_embeddings is not None: + for i in range(bsz): + cbp = ctx_begin_pos[i] + + prefix = inputs_embeds[i, :cbp] + # remove the special token embedding + suffix = inputs_embeds[i, cbp:] + + input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0)) + + inputs_embeds = torch.stack(input_embeds_ctx, dim=0) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings diff --git a/diffusers/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/diffusers/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py new file mode 100755 index 0000000..ff23247 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py @@ -0,0 +1,348 @@ +# Copyright 2024 Salesforce.com, inc. +# Copyright 2024 The HuggingFace Team. All rights reserved.# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union + +import PIL.Image +import torch +from transformers import CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import PNDMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .blip_image_processing import BlipImageProcessor +from .modeling_blip2 import Blip2QFormerModel +from .modeling_ctx_clip import ContextCLIPTextModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers.pipelines import BlipDiffusionPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained( + ... "Salesforce/blipdiffusion", torch_dtype=torch.float16 + ... ).to("cuda") + + + >>> cond_subject = "dog" + >>> tgt_subject = "dog" + >>> text_prompt_input = "swimming underwater" + + >>> cond_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg" + ... ) + >>> guidance_scale = 7.5 + >>> num_inference_steps = 25 + >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate" + + + >>> output = blip_diffusion_pipe( + ... text_prompt_input, + ... cond_image, + ... cond_subject, + ... tgt_subject, + ... guidance_scale=guidance_scale, + ... num_inference_steps=num_inference_steps, + ... neg_prompt=negative_prompt, + ... height=512, + ... width=512, + ... ).images + >>> output[0].save("image.png") + ``` +""" + + +class BlipDiffusionPipeline(DiffusionPipeline): + """ + Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer ([`CLIPTokenizer`]): + Tokenizer for the text encoder + text_encoder ([`ContextCLIPTextModel`]): + Text encoder to encode the text prompt + vae ([`AutoencoderKL`]): + VAE model to map the latents to the image + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + scheduler ([`PNDMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + qformer ([`Blip2QFormerModel`]): + QFormer model to get multi-modal embeddings from the text and image. + image_processor ([`BlipImageProcessor`]): + Image Processor to preprocess and postprocess the image. + ctx_begin_pos (int, `optional`, defaults to 2): + Position of the context token in the text encoder. + """ + + model_cpu_offload_seq = "qformer->text_encoder->unet->vae" + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: ContextCLIPTextModel, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + scheduler: PNDMScheduler, + qformer: Blip2QFormerModel, + image_processor: BlipImageProcessor, + ctx_begin_pos: int = 2, + mean: List[float] = None, + std: List[float] = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + image_processor=image_processor, + ) + self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std) + + def get_query_embeddings(self, input_image, src_subject): + return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False) + + # from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it + def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20): + rv = [] + for prompt, tgt_subject in zip(prompts, tgt_subjects): + prompt = f"a {tgt_subject} {prompt.strip()}" + # a trick to amplify the prompt + rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps))) + + return rv + + # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_prompt(self, query_embeds, prompt, device=None): + device = device or self._execution_device + + # embeddings for prompt, with query_embeds as context + max_len = self.text_encoder.text_model.config.max_position_embeddings + max_len -= self.qformer.config.num_query_tokens + + tokenized_prompt = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=max_len, + return_tensors="pt", + ).to(device) + + batch_size = query_embeds.shape[0] + ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size + + text_embeddings = self.text_encoder( + input_ids=tokenized_prompt.input_ids, + ctx_embeddings=query_embeds, + ctx_begin_pos=ctx_begin_pos, + )[0] + + return text_embeddings + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: List[str], + reference_image: PIL.Image.Image, + source_subject_category: List[str], + target_subject_category: List[str], + latents: Optional[torch.Tensor] = None, + guidance_scale: float = 7.5, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + neg_prompt: Optional[str] = "", + prompt_strength: float = 1.0, + prompt_reps: int = 20, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`List[str]`): + The prompt or prompts to guide the image generation. + reference_image (`PIL.Image.Image`): + The reference image to condition the generation on. + source_subject_category (`List[str]`): + The source subject category. + target_subject_category (`List[str]`): + The target subject category. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by random sampling. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + height (`int`, *optional*, defaults to 512): + The height of the generated image. + width (`int`, *optional*, defaults to 512): + The width of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + neg_prompt (`str`, *optional*, defaults to ""): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_strength (`float`, *optional*, defaults to 1.0): + The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps + to amplify the prompt. + prompt_reps (`int`, *optional*, defaults to 20): + The number of times the prompt is repeated along with prompt_strength to amplify the prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + reference_image = self.image_processor.preprocess( + reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt" + )["pixel_values"] + reference_image = reference_image.to(device) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(source_subject_category, str): + source_subject_category = [source_subject_category] + if isinstance(target_subject_category, str): + target_subject_category = [target_subject_category] + + batch_size = len(prompt) + + prompt = self._build_prompt( + prompts=prompt, + tgt_subjects=target_subject_category, + prompt_strength=prompt_strength, + prompt_reps=prompt_reps, + ) + query_embeds = self.get_query_embeddings(reference_image, source_subject_category) + text_embeddings = self.encode_prompt(query_embeds, prompt, device) + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + max_length = self.text_encoder.text_model.config.max_position_embeddings + + uncond_input = self.tokenizer( + [neg_prompt] * batch_size, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder( + input_ids=uncond_input.input_ids.to(device), + ctx_embeddings=None, + )[0] + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1) + latents = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=height // scale_down_factor, + width=width // scale_down_factor, + generator=generator, + latents=latents, + dtype=self.unet.dtype, + device=device, + ) + # set timesteps + extra_set_kwargs = {} + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + noise_pred = self.unet( + latent_model_input, + timestep=t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=None, + mid_block_additional_residual=None, + )["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + )["prev_sample"] + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/cogvideo/__init__.py b/diffusers/src/diffusers/pipelines/cogvideo/__init__.py new file mode 100755 index 0000000..763f861 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/__init__.py @@ -0,0 +1,65 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"] + _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"] + _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"] + _import_structure["pipeline_cogvideox_inpainting"] = ["CogVideoXInpaintPipeline"] + _import_structure["pipeline_cogvideox_inpainting_branch"] = ["CogVideoXDualInpaintPipeline"] + _import_structure["pipeline_cogvideox_inpainting_sft"] = ["CogVideoXSFTInpaintPipeline"] + _import_structure["pipeline_cogvideox_inpainting_selfguidance"] = ["CogVideoXSelfGuidanceInpaintPipeline"] + _import_structure["pipeline_cogvideox_image2video_inpainting"] = ["CogVideoXImageToVideoInpaintPipeline"] + _import_structure["pipeline_cogvideox_inpainting_i2v_branch"] = ["CogVideoXI2VDualInpaintPipeline"] + _import_structure["pipeline_cogvideox_inpainting_i2v_branch_anyl"] = ["CogVideoXI2VDualInpaintAnyLPipeline"] + _import_structure["pipeline_cogvideox_inpainting_i2v_anyl"] = ["CogVideoXI2VInpaintAnyLPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_cogvideox import CogVideoXPipeline + from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline + from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline + from .pipeline_cogvideox_inpainting import CogVideoXInpaintPipeline + from .pipeline_cogvideox_inpainting_branch import CogVideoXDualInpaintPipeline + from .pipeline_cogvideox_inpainting_selfguidance import CogVideoXSelfGuidanceInpaintPipeline + from .pipeline_cogvideox_inpainting_sft import CogVideoXSFTInpaintPipeline + from .pipeline_cogvideox_image2video_inpainting import CogVideoXImageToVideoInpaintPipeline + from .pipeline_cogvideox_inpainting_i2v_branch import CogVideoXI2VDualInpaintPipeline + from .pipeline_cogvideox_inpainting_i2v_anyl import CogVideoXI2VInpaintAnyLPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py new file mode 100755 index 0000000..02497e7 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -0,0 +1,742 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py new file mode 100755 index 0000000..a1576be --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -0,0 +1,827 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Callable, Dict, List, Optional, Tuple, Union + +import PIL +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import CogVideoXImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + >>> video = pipe(image, prompt, use_dynamic_cfg=True) + >>> export_to_video(video.frames[0], "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class CogVideoXImageToVideoPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int = 1, + num_channels_latents: int = 16, + num_frames: int = 13, + height: int = 60, + width: int = 90, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ): + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = self.vae.config.scaling_factor * image_latents + + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, image_latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + video=None, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input video to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + image = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + latent_channels = self.transformer.config.in_channels // 2 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video_inpainting.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video_inpainting.py new file mode 100755 index 0000000..4ed9d93 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video_inpainting.py @@ -0,0 +1,1057 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Callable, Dict, List, Optional, Tuple, Union + +import PIL +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import CogVideoXImageToVideoInpaintPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = CogVideoXImageToVideoInpaintPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + >>> video = pipe(image, prompt, use_dynamic_cfg=True) + >>> export_to_video(video.frames[0], "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class CogVideoXImageToVideoInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.masked_video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_binarize=True, do_convert_grayscale=True) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents_original( + self, + image: torch.Tensor, + batch_size: int = 1, + num_channels_latents: int = 16, + num_frames: int = 13, + height: int = 60, + width: int = 90, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ): + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = self.vae.config.scaling_factor * image_latents + + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, image_latents + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int = 1, + num_channels_latents: int = 16, + num_frames: int = 13, + height: int = 60, + width: int = 90, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + # inpainting args + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = self.vae.config.scaling_factor * image_latents + + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=dtype) + + if video.shape[-3] == 16: + video_latents = video + else: + video_latents = self._encode_vae_video(video=video, generator=generator) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + if is_strength_max: + latents = noise + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma + else: + latents = self.scheduler.add_noise(video_latents, noise, timestep) + + else: + latents = latents.to(device) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + outputs = (latents, image_latents) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def _encode_vae_video(self, video: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + video_latents = [ + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(video.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(self.vae.encode(video), generator=generator) + + video_latents = self.vae.config.scaling_factor * video_latents + + return video_latents.permute(0, 2, 1, 3, 4) + + def prepare_mask_latents( + self, mask, masked_video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=((mask.shape[-3] - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_video = masked_video.to(device=device, dtype=dtype) + + if masked_video.shape[1] == 4: + masked_video_latents = masked_video + else: + masked_video_latents = self._encode_vae_video(masked_video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1, 1) + if masked_video_latents.shape[0] < batch_size: + if not batch_size % masked_video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_video_latents = masked_video_latents.repeat(batch_size // masked_video_latents.shape[0], 1, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_video_latents = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_video_latents = masked_video_latents.to(device=device, dtype=dtype) + return mask, masked_video_latents + + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + video=None, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + # inpainting args + video: List[PIL.Image.Image] = None, + masked_video: List[PIL.Image.Image] = None, + masked_video_latents: torch.Tensor = None, + strength: float = 1.0, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input video to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, timesteps=timesteps, strength=strength, device=device + ) + self._num_timesteps = len(timesteps) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Prepare latents + image = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + # 5. Preprocess mask and video + original_video = video + init_video = self.video_processor.preprocess_video(original_video, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + + latent_channels = self.transformer.config.in_channels // 2 + return_video_latents = latent_channels == 16 + latents_outputs = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + video=init_video, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=return_video_latents, + ) + + if return_video_latents: + latents, image_latents, noise, video_latents = latents_outputs + else: + latents, image_latents, noise = latents_outputs + + # latents, image_latents = self.prepare_latents_original( + # image, + # batch_size * num_videos_per_prompt, + # latent_channels, + # num_frames, + # height, + # width, + # prompt_embeds.dtype, + # device, + # generator, + # latents, + # ) + + # 5. Prepare mask latent variables + mask_condition = self.masked_video_processor.preprocess_video(masked_video, height=height, width=width).to(device) + if masked_video_latents is None: + masked_video = init_video * (mask_condition < 0.5) + else: + masked_video = masked_video_latents + mask, masked_video_latents = self.prepare_mask_latents( + mask_condition, + masked_video, + batch_size * num_videos_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + mask = mask.permute(0, 2, 1, 3, 4).repeat(1, 1, latent_channels, 1, 1) + + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # inpainting + init_latents_proper = video_latents.to(device) + if do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + init_mask = init_mask.to(device) + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper.to(device), noise.to(device), torch.tensor([noise_timestep]).to(device) + ).to(device) + before_latents = latents + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting.py new file mode 100755 index 0000000..748b9ff --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting.py @@ -0,0 +1,944 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL.Image +from PIL import Image, ImageFilter, ImageOps +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class CogVideoXInpaintPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video inpainting using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.masked_video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_binarize=True, do_convert_grayscale=True) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + # logger.warning( + # "The following part of your input was truncated because `max_sequence_length` is set to " + # f" {max_sequence_length} tokens: {removed_text}" + # ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if (video is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of video + Noise." + "However, either the video or the noise timestep has not been provided." + ) + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=dtype) + + if video.shape[-3] == 16: + video_latents = video + else: + video_latents = self._encode_vae_video(video=video, generator=generator) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def _encode_vae_video(self, video: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + video_latents = [ + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(video.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(self.vae.encode(video), generator=generator) + + video_latents = self.vae.config.scaling_factor * video_latents + + return video_latents.permute(0, 2, 1, 3, 4) + + def prepare_mask_latents( + self, mask, masked_video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=((mask.shape[-3] - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_video = masked_video.to(device=device, dtype=dtype) + + if masked_video.shape[1] == 4: + masked_video_latents = masked_video + else: + masked_video_latents = self._encode_vae_video(masked_video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1, 1) + if masked_video_latents.shape[0] < batch_size: + if not batch_size % masked_video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_video_latents = masked_video_latents.repeat(batch_size // masked_video_latents.shape[0], 1, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_video_latents = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_video_latents = masked_video_latents.to(device=device, dtype=dtype) + return mask, masked_video_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + video=None, + masked_video=None, + ): + if video is None: + raise ValueError( + f"Video is required for inpainting." + ) + if masked_video is None: + raise ValueError( + f"masked_video is required for inpainting." + ) + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + # inpainting args + video: List[PIL.Image.Image] = None, + masked_video: List[PIL.Image.Image] = None, + masked_video_latents: torch.Tensor = None, + strength: float = 1.0, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + video, + masked_video, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + self._num_timesteps = len(timesteps) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + + # 5. Preprocess mask and video + crops_coords = None + resize_mode = "default" + + original_video = video + init_video = self.video_processor.preprocess_video(original_video, height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + return_video_latents = latent_channels == 16 + latents_outputs = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + video=init_video, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=return_video_latents, + ) # [B, F, C, H, W] + + if return_video_latents: + latents, noise, video_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask_condition = self.masked_video_processor.preprocess_video(masked_video, height=height, width=width) + if masked_video_latents is None: + masked_video = init_video * (mask_condition < 0.5) + else: + masked_video = masked_video_latents + mask, masked_video_latents = self.prepare_mask_latents( + mask_condition, + masked_video, + batch_size * num_videos_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + mask = mask.permute(0, 2, 1, 3, 4).repeat(1, 1, latent_channels, 1, 1) + + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + latents = latents.to(device) + + # inpainting + init_latents_proper = video_latents.to(device) + if do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + init_mask = init_mask.to(device) + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper.to(device), noise.to(device), torch.tensor([noise_timestep]).to(device) + ).to(device) + before_latents = latents + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_video_latents = callback_outputs.pop("masked_video_latents", masked_video_latents) + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_branch.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_branch.py new file mode 100755 index 0000000..0a3133b --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_branch.py @@ -0,0 +1,962 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL.Image +from PIL import Image, ImageFilter, ImageOps +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.branch_cogvideox import CogvideoXBranchModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + +# import matplotlib.pyplot as plt +# import numpy as np +# import os + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class CogVideoXDualInpaintPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + branch: CogvideoXBranchModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, branch=branch + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.masked_video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_binarize=True, do_convert_grayscale=True) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + # logger.warning( + # "The following part of your input was truncated because `max_sequence_length` is set to " + # f" {max_sequence_length} tokens: {removed_text}" + # ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if (video is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of video + Noise." + "However, either the video or the noise timestep has not been provided." + ) + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=dtype) + + if video.shape[-3] == 16: + video_latents = video + else: + video_latents = self._encode_vae_video(video=video, generator=generator) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def _encode_vae_video(self, video: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + video_latents = [ + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(video.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(self.vae.encode(video), generator=generator) + + video_latents = self.vae.config.scaling_factor * video_latents + + return video_latents.permute(0, 2, 1, 3, 4) + + def prepare_mask_latents( + self, mask, masked_video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=((mask.shape[-3] - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_video = masked_video.to(device=device, dtype=dtype) + + if masked_video.shape[1] == 4: + masked_video_latents = masked_video + else: + masked_video_latents = self._encode_vae_video(masked_video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1, 1) + if masked_video_latents.shape[0] < batch_size: + if not batch_size % masked_video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_video_latents = masked_video_latents.repeat(batch_size // masked_video_latents.shape[0], 1, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_video_latents = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_video_latents = masked_video_latents.to(device=device, dtype=dtype) + return mask, masked_video_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + # inpainting args + video: List[PIL.Image.Image] = None, + masked_video: List[PIL.Image.Image] = None, + masked_video_latents: torch.Tensor = None, + strength: float = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, + conditioning_scale: Union[float, List[float]] = 1.0, + mask_background: Optional[bool] = False, + add_first: Optional[bool] = False, + wo_text: Optional[bool] = False, + mask_add: Optional[bool] = False, + replace_gt:Optional[bool] = False, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + + # 5. Preprocess mask and video + crops_coords = None + resize_mode = "default" + + original_video = video + init_video = self.video_processor.preprocess_video(original_video, height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + return_video_latents = latent_channels == 16 + latents_outputs = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + video=init_video, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=return_video_latents, + ) # [B, F, C, H, W] + + if return_video_latents: + latents, noise, video_latents = latents_outputs + else: + latents, noise = latents_outputs + + + # 7. Prepare mask latent variables + mask_condition = self.masked_video_processor.preprocess_video(masked_video, height=height, width=width) + if masked_video_latents is None: + if mask_background: + masked_video = init_video * (mask_condition >= 0.5) + else: + masked_video = init_video * (mask_condition < 0.5) + else: + masked_video = masked_video_latents + + mask, masked_video_latents = self.prepare_mask_latents( + mask_condition, + masked_video, + batch_size * num_videos_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + mask = mask.permute(0, 2, 1, 3, 4).repeat(1, 1, latent_channels, 1, 1) + + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # branch block + # mask = 1.0 - mask + latent_branch_input = torch.cat([masked_video_latents, mask[:, :, :1, :, :]], dim=-3) + + + branch_block_samples = self.branch( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + branch_cond=latent_branch_input, + branch_mode=control_mode, + conditioning_scale=conditioning_scale, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + wo_text=wo_text, + return_dict=False, + )[0] + branch_block_samples = [block_sample.to(dtype=prompt_embeds.dtype) for block_sample in branch_block_samples] + + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + branch_block_samples=branch_block_samples, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + add_first=add_first, + branch_block_masks=mask[:, :, :1, :, :] if mask_add else None, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # inpainting + latents = latents.to(device) + if replace_gt: + init_latents_proper = video_latents.to(device) + if do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + init_mask = init_mask.to(device) + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper.to(device), noise.to(device), torch.tensor([noise_timestep]).to(device) + ).to(device) + before_latents = latents + if mask_background: + latents = init_mask * init_latents_proper + (1 - init_mask) * latents + else: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_video_latents = callback_outputs.pop("masked_video_latents", masked_video_latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_anyl.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_anyl.py new file mode 100755 index 0000000..eed9ef2 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_anyl.py @@ -0,0 +1,1075 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL.Image +from PIL import Image, ImageFilter, ImageOps +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PipelineImageInput +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.branch_cogvideox import CogvideoXBranchModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + +# import matplotlib.pyplot as plt +# import numpy as np +# import os + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class CogVideoXI2VInpaintAnyLPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + branch ([`CogvideoXBranchModel`]): + A branch model to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + branch: CogvideoXBranchModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, branch=branch + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.masked_video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_binarize=True, do_convert_grayscale=True) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + # logger.warning( + # "The following part of your input was truncated because `max_sequence_length` is set to " + # f" {max_sequence_length} tokens: {removed_text}" + # ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, + image: torch.Tensor = None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if len(image.shape) == 4: + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = self.vae.config.scaling_factor * image_latents + elif len(image.shape) == 5: + image_latents = image + else: + raise ValueError(f"image shape is not valid: {image.shape}") + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if (video is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of video + Noise." + "However, either the video or the noise timestep has not been provided." + ) + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=dtype) + + if video.shape[-3] == 16: + video_latents = video + else: + video_latents = self._encode_vae_video(video=video, generator=generator) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents, image_latents) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def _encode_vae_video(self, video: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + video_latents = [ + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(video.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(self.vae.encode(video), generator=generator) + + video_latents = self.vae.config.scaling_factor * video_latents + + return video_latents.permute(0, 2, 1, 3, 4) + + def prepare_mask_latents( + self, mask, masked_video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=((mask.shape[-3] - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_video = masked_video.to(device=device, dtype=dtype) + + if masked_video.shape[1] == 4: + masked_video_latents = masked_video + else: + masked_video_latents = self._encode_vae_video(masked_video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1, 1) + if masked_video_latents.shape[0] < batch_size: + if not batch_size % masked_video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_video_latents = masked_video_latents.repeat(batch_size // masked_video_latents.shape[0], 1, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_video_latents = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_video_latents = masked_video_latents.to(device=device, dtype=dtype) + return mask, masked_video_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + video: List[PIL.Image.Image] = None, + masks: List[PIL.Image.Image] = None, + masked_video_latents: torch.Tensor = None, + strength: float = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, + conditioning_scale: Union[float, List[float]] = 1.0, + mask_background: Optional[bool] = False, + add_first: Optional[bool] = False, + wo_text: Optional[bool] = False, + id_pool_resample_learnable: Optional[bool] = False, + mask_add: Optional[bool] = False, + replace_gt:Optional[bool] = False, + stride: int = 24, + prev_clip_weight: Optional[float] = 0.0, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(video, (list, tuple)): + total_frames = len(video) + else: + total_frames = video.shape[1] + + n_windows = (total_frames - num_frames) // stride + 1 + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs + self.check_inputs( + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, timesteps=timesteps, strength=strength, device=device + ) + self._num_timesteps = len(timesteps) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + + + if stride < num_frames: + num_frame_latents = ((num_frames - 1) // self.vae_scale_factor_temporal + 1) * n_windows - (n_windows - 1) * ((num_frames - stride) // self.vae_scale_factor_temporal + 1) + elif stride == num_frames: + num_frame_latents = ((num_frames - 1) // self.vae_scale_factor_temporal) * n_windows + 1 + else: + raise ValueError(f"stride: {stride}, num_frames: {num_frames}") + latent_channels = self.transformer.config.in_channels // 2 + frame_counts = torch.zeros(num_frame_latents, device=device) + frame_accumulator = torch.zeros( + (batch_size, num_frame_latents, latent_channels, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial), + device=device, + dtype=prompt_embeds.dtype + ) + + prev_window_states_dict = {} + + for window_idx in range(n_windows): + start_frame = window_idx * stride + end_frame = start_frame + num_frames + + if isinstance(video, (list, tuple)): + window_video = video[start_frame:end_frame] + else: + window_video = video[:, start_frame:end_frame] + + if isinstance(masks, (list, tuple)): + window_mask = masks[start_frame:end_frame] + else: + window_mask = masks[:, start_frame:end_frame] + + # 5. Prepare latents + window_video = self.video_processor.preprocess_video(window_video, height=height, width=width) + window_video = window_video.to(dtype=torch.float32) + if window_idx == 0: + image_ = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + else: + + if -int((num_frames - stride) // self.vae_scale_factor_temporal) < 0: + image_ = latents[:, -int((num_frames - stride) // self.vae_scale_factor_temporal) - 1:-int((num_frames - stride) // self.vae_scale_factor_temporal) + 1 - 1] + # print(f"image was set by the frame {-int((num_frames - stride) // self.vae_scale_factor_temporal) - 1} of the previous window latents, num_frames: {num_frames}, stride: {stride}") + else: + image_ = latents[:, -1:] + # print(f"image was set by the last frame of the previous window latents") + latents_outputs = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents=None, # must be None for all windows + video=window_video, + image=image_, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=latent_channels == 16, + ) + + if latent_channels == 16: + latents, image_latents, noise, video_latents = latents_outputs + else: + latents, image_latents, noise = latents_outputs + + # 7. Prepare mask latents + mask_condition = self.masked_video_processor.preprocess_video(window_mask, height=height, width=width) + if masked_video_latents is not None: + masked_video_latents = None + if masked_video_latents is None: + if mask_background: + masked_video = window_video * (mask_condition >= 0.5) + else: + masked_video = window_video * (mask_condition < 0.5) + else: + masked_video = masked_video_latents[:, start_frame:end_frame] + + mask, masked_video_latents = self.prepare_mask_latents( + mask_condition, + masked_video, + batch_size * num_videos_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + mask = mask.permute(0, 2, 1, 3, 4).repeat(1, 1, latent_channels, 1, 1) + + # 8. Prepare extra step kwargs + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + old_pred_original_sample = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_video_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_video_input = self.scheduler.scale_model_input(latent_video_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_video_input, latent_image_input], dim=2) + timestep = t.expand(latent_model_input.shape[0]) + + # branch block + latent_branch_input = torch.cat([masked_video_latents, mask[:, :, :1, :, :]], dim=-3) + + branch_block_samples = self.branch( + hidden_states=latent_video_input, + encoder_hidden_states=prompt_embeds, + branch_cond=latent_branch_input, + branch_mode=control_mode, + conditioning_scale=conditioning_scale, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + wo_text=wo_text, + return_dict=False, + )[0] + + branch_block_samples = [block_sample.to(dtype=prompt_embeds.dtype) for block_sample in branch_block_samples] + + current_attention_kwargs = attention_kwargs.copy() if attention_kwargs else {} + noise_pred, hidden_states_list = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=current_attention_kwargs, + return_hidden_states=True, + return_resample_mask=False, + return_dict=False, + ) + noise_pred = noise_pred.float() + + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # inpainting + latents = latents.to(device) + + if replace_gt: + init_latents_proper = video_latents.to(device) + if do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + init_mask = init_mask.to(device) + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper.to(device), noise.to(device), torch.tensor([noise_timestep]).to(device) + ).to(device) + + if mask_background: + latents = init_mask * init_latents_proper + (1 - init_mask) * latents + else: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_video_latents = callback_outputs.pop("masked_video_latents", masked_video_latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + for i in range(latents.shape[1]): + compressed_start_frame = window_idx * latents.shape[1] + if window_idx > 0 and stride < num_frames: + compressed_start_frame -= (int((num_frames - stride) // self.vae_scale_factor_temporal) + 1) * window_idx + elif window_idx > 0 and stride == num_frames: + compressed_start_frame -= window_idx + elif window_idx == 0: + compressed_start_frame = 0 + else: + raise ValueError(f"stride: {stride}, num_frames: {num_frames}") + compressed_frame_idx = compressed_start_frame + i + frame_accumulator[:, compressed_frame_idx] += latents[:, i] + frame_counts[compressed_frame_idx] += 1 + + + for i in range(num_frame_latents): + if frame_counts[i] > 0: + frame_accumulator[:, i] /= frame_counts[i] + + if not output_type == "latent": + video = self.decode_latents(frame_accumulator) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = frame_accumulator + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) \ No newline at end of file diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_branch.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_branch.py new file mode 100755 index 0000000..6de883c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_branch.py @@ -0,0 +1,1044 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL.Image +from PIL import Image, ImageFilter, ImageOps +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PipelineImageInput +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.branch_cogvideox import CogvideoXBranchModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + +# import matplotlib.pyplot as plt +# import numpy as np +# import os + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class CogVideoXI2VDualInpaintPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + branch ([`CogvideoXBranchModel`]): + A branch model to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + branch: CogvideoXBranchModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, branch=branch + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.masked_video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_binarize=True, do_convert_grayscale=True) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + # logger.warning( + # "The following part of your input was truncated because `max_sequence_length` is set to " + # f" {max_sequence_length} tokens: {removed_text}" + # ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, + image: torch.Tensor = None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = self.vae.config.scaling_factor * image_latents + + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if (video is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of video + Noise." + "However, either the video or the noise timestep has not been provided." + ) + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=dtype) + + if video.shape[-3] == 16: + video_latents = video + else: + video_latents = self._encode_vae_video(video=video, generator=generator) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + else: + video_latents = None + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents, image_latents) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def _encode_vae_video(self, video: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + video_latents = [ + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(video.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(self.vae.encode(video), generator=generator) + + video_latents = self.vae.config.scaling_factor * video_latents + + return video_latents.permute(0, 2, 1, 3, 4) + + def prepare_mask_latents( + self, mask, masked_video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=((mask.shape[-3] - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_video = masked_video.to(device=device, dtype=dtype) + + if masked_video.shape[1] == 4: + masked_video_latents = masked_video + else: + masked_video_latents = self._encode_vae_video(masked_video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1, 1) + if masked_video_latents.shape[0] < batch_size: + if not batch_size % masked_video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_video_latents = masked_video_latents.repeat(batch_size // masked_video_latents.shape[0], 1, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_video_latents = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_video_latents = masked_video_latents.to(device=device, dtype=dtype) + return mask, masked_video_latents + + def prepare_video_latents( + self, video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + + video = video.to(device=device, dtype=dtype) + + if video.shape[1] == 4: + video_latents = video + else: + video_latents = self._encode_vae_video(video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + + if video_latents.shape[0] < batch_size: + if not batch_size % video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + + video_latents = torch.cat([video_latents] * 2) if do_classifier_free_guidance else video_latents + + # aligning device to prevent device errors when concating it with the latent model input + video_latents = video_latents.to(device=device, dtype=dtype) + return video_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + # inpainting args + video: List[PIL.Image.Image] = None, + condition_video: List[PIL.Image.Image] = None, + condition_video_latents: torch.Tensor = None, + strength: float = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, + conditioning_scale: Union[float, List[float]] = 1.0, + mask_background: Optional[bool] = False, + add_first: Optional[bool] = False, + wo_text: Optional[bool] = False, + id_pool_resample_learnable: Optional[bool] = False, + mask_add: Optional[bool] = False, + replace_gt:Optional[bool] = False, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, timesteps=timesteps, strength=strength, device=device + ) + self._num_timesteps = len(timesteps) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + + # 5. Preprocess mask and video + crops_coords = None + resize_mode = "default" + + original_video = video + init_video = self.video_processor.preprocess_video(original_video, height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + image = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels // 2 + return_video_latents = latent_channels == 16 + latents_outputs = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + video=init_video, + image=image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=(return_video_latents and not is_strength_max), # anton, to save time + ) # [B, F, C, H, W] + + if (return_video_latents and not is_strength_max): # anton, to save time + latents, image_latents, noise, video_latents = latents_outputs + else: + latents, image_latents, noise = latents_outputs + + + # 7. Prepare mask latent variables + condition_video = self.video_processor.preprocess_video(condition_video, height=height, width=width) + + + condition_video_latents = self.prepare_video_latents( + condition_video, + batch_size * num_videos_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_video_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_video_input = self.scheduler.scale_model_input(latent_video_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_video_input, latent_image_input], dim=2) + + # latent_branch_input = torch.cat([masked_video_latents, mask[:, :, :1, :, :]], dim=-3) + latent_branch_input = condition_video_latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + branch_block_samples = self.branch( + hidden_states=latent_video_input, + encoder_hidden_states=prompt_embeds, + branch_cond=latent_branch_input, + branch_mode=control_mode, + conditioning_scale=conditioning_scale, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + wo_text=wo_text, + return_dict=False, + )[0] + branch_block_samples = [block_sample.to(dtype=prompt_embeds.dtype) for block_sample in branch_block_samples] + + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + branch_block_samples=branch_block_samples, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + add_first=add_first, + branch_block_masks=mask[:, :, :1, :, :] if mask_add else None, + id_pool_resample_learnable=id_pool_resample_learnable, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + + # inpainting + latents = latents.to(device) + + # if replace_gt: + # init_latents_proper = video_latents.to(device) + # if do_classifier_free_guidance: + # init_mask, _ = mask.chunk(2) + # else: + # init_mask = mask + # init_mask = init_mask.to(device) + + # if i < len(timesteps) - 1: + # noise_timestep = timesteps[i + 1] + # init_latents_proper = self.scheduler.add_noise( + # init_latents_proper.to(device), noise.to(device), torch.tensor([noise_timestep]).to(device) + # ).to(device) + # before_latents = latents + # if mask_background: + # latents = init_mask * init_latents_proper + (1 - init_mask) * latents + # else: + # latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # mask = callback_outputs.pop("mask", mask) + condition_video_latents = callback_outputs.pop("condition_video_latents", condition_video_latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_branch_anyl.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_branch_anyl.py new file mode 100755 index 0000000..8349614 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_branch_anyl.py @@ -0,0 +1,1102 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL.Image +from PIL import Image, ImageFilter, ImageOps +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PipelineImageInput +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.branch_cogvideox import CogvideoXBranchModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class CogVideoXI2VDualInpaintAnyLPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + branch ([`CogvideoXBranchModel`]): + A branch model to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + branch: CogvideoXBranchModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, branch=branch + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.masked_video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_binarize=True, do_convert_grayscale=True) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + # logger.warning( + # "The following part of your input was truncated because `max_sequence_length` is set to " + # f" {max_sequence_length} tokens: {removed_text}" + # ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, + image: torch.Tensor = None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if len(image.shape) == 4: + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = self.vae.config.scaling_factor * image_latents + + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if (video is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of video + Noise." + "However, either the video or the noise timestep has not been provided." + ) + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=dtype) + + if video.shape[-3] == 16: + video_latents = video + else: + video_latents = self._encode_vae_video(video=video, generator=generator) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents, image_latents) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def _encode_vae_video(self, video: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + video_latents = [ + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(video.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(self.vae.encode(video), generator=generator) + + video_latents = self.vae.config.scaling_factor * video_latents + + return video_latents.permute(0, 2, 1, 3, 4) + + def prepare_mask_latents( + self, mask, masked_video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=((mask.shape[-3] - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_video = masked_video.to(device=device, dtype=dtype) + + if masked_video.shape[1] == 4: + masked_video_latents = masked_video + else: + masked_video_latents = self._encode_vae_video(masked_video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1, 1) + if masked_video_latents.shape[0] < batch_size: + if not batch_size % masked_video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_video_latents = masked_video_latents.repeat(batch_size // masked_video_latents.shape[0], 1, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_video_latents = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_video_latents = masked_video_latents.to(device=device, dtype=dtype) + return mask, masked_video_latents + + def prepare_video_latents( + self, video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + + video = video.to(device=device, dtype=dtype) + + if video.shape[1] == 4: + video_latents = video + else: + video_latents = self._encode_vae_video(video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + + if video_latents.shape[0] < batch_size: + if not batch_size % video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + + video_latents = torch.cat([video_latents] * 2) if do_classifier_free_guidance else video_latents + + # aligning device to prevent device errors when concating it with the latent model input + video_latents = video_latents.to(device=device, dtype=dtype) + return video_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + video: List[PIL.Image.Image] = None, + condition_video: List[PIL.Image.Image] = None, + condition_video_latents: torch.Tensor = None, + strength: float = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, + conditioning_scale: Union[float, List[float]] = 1.0, + mask_background: Optional[bool] = False, + add_first: Optional[bool] = False, + wo_text: Optional[bool] = False, + id_pool_resample_learnable: Optional[bool] = False, + mask_add: Optional[bool] = False, + replace_gt:Optional[bool] = False, + stride: int = 24, + prev_clip_weight: Optional[float] = 0.0, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(video, (list, tuple)): + total_frames = len(video) + else: + total_frames = video.shape[1] + + n_windows = (total_frames - num_frames) // stride + 1 + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs + self.check_inputs( + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, timesteps=timesteps, strength=strength, device=device + ) + self._num_timesteps = len(timesteps) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + + if stride < num_frames: + num_frame_latents = ((num_frames - 1) // self.vae_scale_factor_temporal + 1) * n_windows - (n_windows - 1) * ((num_frames - stride) // self.vae_scale_factor_temporal + 1) + elif stride == num_frames: + num_frame_latents = ((num_frames - 1) // self.vae_scale_factor_temporal) * n_windows + 1 + else: + raise ValueError(f"stride: {stride}, num_frames: {num_frames}") + latent_channels = self.transformer.config.in_channels // 2 + frame_counts = torch.zeros(num_frame_latents, device=device) + frame_accumulator = torch.zeros( + (batch_size, num_frame_latents, latent_channels, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial), + device=device, + dtype=prompt_embeds.dtype + ) + + prev_window_states_dict = {} + for window_idx in range(n_windows): + start_frame = window_idx * stride + end_frame = start_frame + num_frames + + if isinstance(video, (list, tuple)): + window_video = video[start_frame:end_frame] + else: + window_video = video[:, start_frame:end_frame] + + if isinstance(condition_video, (list, tuple)): + window_condition_video = condition_video[start_frame:end_frame] + else: + window_condition_video = condition_video[:, start_frame:end_frame] + + # 5. Prepare latents + window_video = self.video_processor.preprocess_video(window_video, height=height, width=width) + window_video = window_video.to(dtype=torch.float32) + if window_idx == 0: + image_ = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + else: + if -int((num_frames - stride) // self.vae_scale_factor_temporal) < 0: + image_ = latents[:, -int((num_frames - stride) // self.vae_scale_factor_temporal) - 1:-int((num_frames - stride) // self.vae_scale_factor_temporal) + 1 - 1] + # print(f"image was set by the frame {-int((num_frames - stride) // self.vae_scale_factor_temporal) - 1} of the previous window latents, num_frames: {num_frames}, stride: {stride}") + else: + image_ = latents[:, -1:] + # print(f"image was set by the last frame of the previous window latents") + latents_outputs = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents=None, # must be None for all windows + video=window_video, + image=image_, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=latent_channels == 16, + ) + + if latent_channels == 16: + latents, image_latents, noise, video_latents = latents_outputs + else: + latents, image_latents, noise = latents_outputs + + # 7. Prepare condition video latents + condition_video = self.video_processor.preprocess_video(window_condition_video, height=height, width=width) + if condition_video_latents is not None: + condition_video_latents = None + + condition_video_latents = self.prepare_video_latents( + condition_video, + batch_size * num_videos_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # ? + # mask = mask.permute(0, 2, 1, 3, 4).repeat(1, 1, latent_channels, 1, 1) + + # 8. Prepare extra step kwargs + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + old_pred_original_sample = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_video_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_video_input = self.scheduler.scale_model_input(latent_video_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_video_input, latent_image_input], dim=2) + timestep = t.expand(latent_model_input.shape[0]) + + # branch block + latent_branch_input = condition_video_latents + + branch_block_samples = self.branch( + hidden_states=latent_video_input, + encoder_hidden_states=prompt_embeds, + branch_cond=latent_branch_input, + branch_mode=control_mode, + conditioning_scale=conditioning_scale, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + wo_text=wo_text, + return_dict=False, + )[0] + + branch_block_samples = [block_sample.to(dtype=prompt_embeds.dtype) for block_sample in branch_block_samples] + + current_attention_kwargs = attention_kwargs.copy() if attention_kwargs else {} + if window_idx > 0: + current_attention_kwargs["prev_hidden_states"] = prev_window_states_dict[0] + current_attention_kwargs["prev_clip_weight"] = prev_clip_weight + # current_attention_kwargs["prev_resample_mask"] = prev_resample_mask #? + noise_pred, hidden_states_list = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + branch_block_samples=branch_block_samples, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=current_attention_kwargs, + add_first=add_first, + branch_block_masks=mask[:, :, :1, :, :] if mask_add else None, + id_pool_resample_learnable=id_pool_resample_learnable, + return_hidden_states=True, + return_resample_mask=False, # previous was true! + return_dict=False, + ) + noise_pred = noise_pred.float() + if window_idx < n_windows - 1: + if t.item() == timesteps[-1]: + prev_window_states_dict[0] = { + layer_idx: hidden_states.detach() + for layer_idx, hidden_states in enumerate(hidden_states_list) + } + # prev_resample_mask = prev_resample_mask.detach() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # inpainting + latents = latents.to(device) + + if replace_gt: + init_latents_proper = video_latents.to(device) + if do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + init_mask = init_mask.to(device) + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper.to(device), noise.to(device), torch.tensor([noise_timestep]).to(device) + ).to(device) + + if mask_background: + latents = init_mask * init_latents_proper + (1 - init_mask) * latents + else: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_video_latents = callback_outputs.pop("masked_video_latents", masked_video_latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + for i in range(latents.shape[1]): + compressed_start_frame = window_idx * latents.shape[1] + if window_idx > 0 and stride < num_frames: + compressed_start_frame -= (int((num_frames - stride) // self.vae_scale_factor_temporal) + 1) * window_idx + elif window_idx > 0 and stride == num_frames: + compressed_start_frame -= window_idx + elif window_idx ==0: + compressed_start_frame = 0 + else: + raise ValueError(f"stride: {stride}, num_frames: {num_frames}") + compressed_frame_idx = compressed_start_frame + i + frame_accumulator[:, compressed_frame_idx] += latents[:, i] + frame_counts[compressed_frame_idx] += 1 + + + for i in range(num_frame_latents): + if frame_counts[i] > 0: + frame_accumulator[:, i] /= frame_counts[i] + + if not output_type == "latent": + video = self.decode_latents(frame_accumulator) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = frame_accumulator + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) \ No newline at end of file diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_branch_anyl_monstr.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_branch_anyl_monstr.py new file mode 100755 index 0000000..fc79ec3 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_i2v_branch_anyl_monstr.py @@ -0,0 +1,1107 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL.Image +from PIL import Image, ImageFilter, ImageOps +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PipelineImageInput +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.branch_cogvideox import CogvideoXBranchModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class CogVideoXI2VDualInpaintAnyLPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + branch ([`CogvideoXBranchModel`]): + A branch model to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + branch: CogvideoXBranchModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, branch=branch + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.masked_video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_binarize=True, do_convert_grayscale=True) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + # logger.warning( + # "The following part of your input was truncated because `max_sequence_length` is set to " + # f" {max_sequence_length} tokens: {removed_text}" + # ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, + image: torch.Tensor = None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if len(image.shape) == 4: + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = self.vae.config.scaling_factor * image_latents + + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if (video is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of video + Noise." + "However, either the video or the noise timestep has not been provided." + ) + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=dtype) + + if video.shape[-3] == 16: + video_latents = video + else: + video_latents = self._encode_vae_video(video=video, generator=generator) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents, image_latents) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def _encode_vae_video(self, video: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + video_latents = [ + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(video.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(self.vae.encode(video), generator=generator) + + video_latents = self.vae.config.scaling_factor * video_latents + + return video_latents.permute(0, 2, 1, 3, 4) + + def prepare_mask_latents( + self, mask, masked_video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=((mask.shape[-3] - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_video = masked_video.to(device=device, dtype=dtype) + + if masked_video.shape[1] == 4: + masked_video_latents = masked_video + else: + masked_video_latents = self._encode_vae_video(masked_video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1, 1) + if masked_video_latents.shape[0] < batch_size: + if not batch_size % masked_video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_video_latents = masked_video_latents.repeat(batch_size // masked_video_latents.shape[0], 1, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_video_latents = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_video_latents = masked_video_latents.to(device=device, dtype=dtype) + return mask, masked_video_latents + + def prepare_video_latents( + self, video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + + video = video.to(device=device, dtype=dtype) + + if video.shape[1] == 4: + video_latents = video + else: + video_latents = self._encode_vae_video(video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + + if video_latents.shape[0] < batch_size: + if not batch_size % video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + + video_latents = torch.cat([video_latents] * 2) if do_classifier_free_guidance else video_latents + + # aligning device to prevent device errors when concating it with the latent model input + video_latents = video_latents.to(device=device, dtype=dtype) + return video_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + video: List[PIL.Image.Image] = None, + condition_video: List[PIL.Image.Image] = None, + condition_video_latents: torch.Tensor = None, + strength: float = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, + conditioning_scale: Union[float, List[float]] = 1.0, + mask_background: Optional[bool] = False, + add_first: Optional[bool] = False, + wo_text: Optional[bool] = False, + id_pool_resample_learnable: Optional[bool] = False, + mask_add: Optional[bool] = False, + replace_gt:Optional[bool] = False, + stride: int = 24, + prev_clip_weight: Optional[float] = 0.0, + window_idx=None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(video, (list, tuple)): + total_frames = len(video) + else: + total_frames = video.shape[1] + + n_windows = (total_frames - num_frames) // stride + 1 + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs + self.check_inputs( + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, timesteps=timesteps, strength=strength, device=device + ) + self._num_timesteps = len(timesteps) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + + if stride < num_frames: + num_frame_latents = ((num_frames - 1) // self.vae_scale_factor_temporal + 1) * n_windows - (n_windows - 1) * ((num_frames - stride) // self.vae_scale_factor_temporal + 1) + elif stride == num_frames: + num_frame_latents = ((num_frames - 1) // self.vae_scale_factor_temporal) * n_windows + 1 + else: + raise ValueError(f"stride: {stride}, num_frames: {num_frames}") + latent_channels = self.transformer.config.in_channels // 2 + frame_counts = torch.zeros(num_frame_latents, device=device) + frame_accumulator = torch.zeros( + (batch_size, num_frame_latents, latent_channels, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial), + device=device, + dtype=prompt_embeds.dtype + ) + + prev_window_states_dict = {} + for _ in range(1): + + start_frame = window_idx * stride + end_frame = start_frame + num_frames + + if isinstance(video, (list, tuple)): + window_video = video[start_frame:end_frame] + else: + window_video = video[:, start_frame:end_frame] + + if isinstance(condition_video, (list, tuple)): + window_condition_video = condition_video[start_frame:end_frame] + else: + window_condition_video = condition_video[:, start_frame:end_frame] + + # 5. Prepare latents + window_video = self.video_processor.preprocess_video(window_video, height=height, width=width) + window_video = window_video.to(dtype=torch.float32) + if window_idx == 0: + image_ = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + else: + if -int((num_frames - stride) // self.vae_scale_factor_temporal) < 0: + image_ = self.latents[:, -int((num_frames - stride) // self.vae_scale_factor_temporal) - 1:-int((num_frames - stride) // self.vae_scale_factor_temporal) + 1 - 1] + # print(f"image was set by the frame {-int((num_frames - stride) // self.vae_scale_factor_temporal) - 1} of the previous window latents, num_frames: {num_frames}, stride: {stride}") + else: + image_ = self.latents[:, -1:] + # print(f"image was set by the last frame of the previous window latents") + latents_outputs = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents=None, # must be None for all windows + video=window_video, + image=image_, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=latent_channels == 16, + ) + + if latent_channels == 16: + latents, image_latents, noise, video_latents = latents_outputs + else: + latents, image_latents, noise = latents_outputs + + self.latents = latents + + # 7. Prepare condition video latents + condition_video = self.video_processor.preprocess_video(window_condition_video, height=height, width=width) + if condition_video_latents is not None: + condition_video_latents = None + + condition_video_latents = self.prepare_video_latents( + condition_video, + batch_size * num_videos_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # ? + # mask = mask.permute(0, 2, 1, 3, 4).repeat(1, 1, latent_channels, 1, 1) + + # 8. Prepare extra step kwargs + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + old_pred_original_sample = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_video_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_video_input = self.scheduler.scale_model_input(latent_video_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_video_input, latent_image_input], dim=2) + timestep = t.expand(latent_model_input.shape[0]) + + # branch block + latent_branch_input = condition_video_latents + + branch_block_samples = self.branch( + hidden_states=latent_video_input, + encoder_hidden_states=prompt_embeds, + branch_cond=latent_branch_input, + branch_mode=control_mode, + conditioning_scale=conditioning_scale, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + wo_text=wo_text, + return_dict=False, + )[0] + + branch_block_samples = [block_sample.to(dtype=prompt_embeds.dtype) for block_sample in branch_block_samples] + + current_attention_kwargs = attention_kwargs.copy() if attention_kwargs else {} + if window_idx > 0: + current_attention_kwargs["prev_hidden_states"] = prev_window_states_dict[0] + current_attention_kwargs["prev_clip_weight"] = prev_clip_weight + # current_attention_kwargs["prev_resample_mask"] = prev_resample_mask #? + noise_pred, hidden_states_list = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + branch_block_samples=branch_block_samples, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=current_attention_kwargs, + add_first=add_first, + branch_block_masks=mask[:, :, :1, :, :] if mask_add else None, + id_pool_resample_learnable=id_pool_resample_learnable, + return_hidden_states=True, + return_resample_mask=False, # previous was true! + return_dict=False, + ) + noise_pred = noise_pred.float() + if window_idx < n_windows - 1: + if t.item() == timesteps[-1]: + prev_window_states_dict[0] = { + layer_idx: hidden_states.detach() + for layer_idx, hidden_states in enumerate(hidden_states_list) + } + self.prev_window_states_dict = prev_window_states_dict + # prev_resample_mask = prev_resample_mask.detach() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # inpainting + latents = latents.to(device) + + if replace_gt: + init_latents_proper = video_latents.to(device) + if do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + init_mask = init_mask.to(device) + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper.to(device), noise.to(device), torch.tensor([noise_timestep]).to(device) + ).to(device) + + if mask_background: + latents = init_mask * init_latents_proper + (1 - init_mask) * latents + else: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_video_latents = callback_outputs.pop("masked_video_latents", masked_video_latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + for i in range(latents.shape[1]): + compressed_start_frame = window_idx * latents.shape[1] + if window_idx > 0 and stride < num_frames: + compressed_start_frame -= (int((num_frames - stride) // self.vae_scale_factor_temporal) + 1) * window_idx + elif window_idx > 0 and stride == num_frames: + compressed_start_frame -= window_idx + elif window_idx ==0: + compressed_start_frame = 0 + else: + raise ValueError(f"stride: {stride}, num_frames: {num_frames}") + compressed_frame_idx = compressed_start_frame + i + frame_accumulator[:, compressed_frame_idx] += latents[:, i] + frame_counts[compressed_frame_idx] += 1 + + + for i in range(num_frame_latents): + if frame_counts[i] > 0: + frame_accumulator[:, i] /= frame_counts[i] + + if not output_type == "latent": + video = self.decode_latents(frame_accumulator) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = frame_accumulator + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) \ No newline at end of file diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_selfguidance.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_selfguidance.py new file mode 100755 index 0000000..1096f36 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_selfguidance.py @@ -0,0 +1,955 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL.Image +from PIL import Image, ImageFilter, ImageOps +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class CogVideoXSelfGuidanceInpaintPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.masked_video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_binarize=True, do_convert_grayscale=True) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + # logger.warning( + # "The following part of your input was truncated because `max_sequence_length` is set to " + # f" {max_sequence_length} tokens: {removed_text}" + # ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if (video is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of video + Noise." + "However, either the video or the noise timestep has not been provided." + ) + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=dtype) + + if video.shape[-3] == 16: + video_latents = video + else: + video_latents = self._encode_vae_video(video=video, generator=generator) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def _encode_vae_video(self, video: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + video_latents = [ + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(video.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(self.vae.encode(video), generator=generator) + + video_latents = self.vae.config.scaling_factor * video_latents + + return video_latents.permute(0, 2, 1, 3, 4) + + def prepare_mask_latents( + self, mask, masked_video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=((mask.shape[-3] - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_video = masked_video.to(device=device, dtype=dtype) + + if masked_video.shape[1] == 4: + masked_video_latents = masked_video + else: + masked_video_latents = self._encode_vae_video(masked_video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1, 1) + if masked_video_latents.shape[0] < batch_size: + if not batch_size % masked_video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_video_latents = masked_video_latents.repeat(batch_size // masked_video_latents.shape[0], 1, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + # masked_video_latents = ( + # torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + # ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_video_latents = masked_video_latents.to(device=device, dtype=dtype) + return mask, masked_video_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + # inpainting args + video: List[PIL.Image.Image] = None, + masked_video: List[PIL.Image.Image] = None, + masked_video_latents: torch.Tensor = None, + strength: float = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, + conditioning_scale: Union[float, List[float]] = 1.0, + mask_background: Optional[bool] = False, + add_first: Optional[bool] = False, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + + # 5. Preprocess mask and video + crops_coords = None + resize_mode = "default" + + original_video = video + init_video = self.video_processor.preprocess_video(original_video, height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + return_video_latents = latent_channels == 16 + latents_outputs = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + video=init_video, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=return_video_latents, + ) # [B, F, C, H, W] + + if return_video_latents: + latents, noise, video_latents = latents_outputs + else: + latents, noise = latents_outputs + + + # 7. Prepare mask latent variables + mask_condition = self.masked_video_processor.preprocess_video(masked_video, height=height, width=width) + if masked_video_latents is None: + if mask_background: + masked_video = init_video * (mask_condition >= 0.5) + else: + masked_video = init_video * (mask_condition < 0.5) + else: + masked_video = masked_video_latents + + mask, masked_video_latents = self.prepare_mask_latents( + mask_condition, + masked_video, + batch_size * num_videos_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + mask = mask.permute(0, 2, 1, 3, 4) + mask = mask.repeat(1, 1, latent_channels, 1, 1) + + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + + latents = self.scheduler.add_noise( + masked_video_latents.to(device), noise.to(device), torch.tensor([timesteps[0]]).to(device) + ).to(device) + + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # first round of masked video noist latent forward pass + init_masked_video_latents = torch.cat([masked_video_latents.to(device)] * 2) if do_classifier_free_guidance else masked_video_latents.to(device) + latent_masked_video_model_input = self.scheduler.add_noise( + init_masked_video_latents.to(device), noise.to(device), torch.tensor([t]).to(device) + ).to(device) + latent_masked_video_model_input = self.scheduler.scale_model_input(latent_masked_video_model_input, t) + + # predict noise model_output + noise_pred, masked_video_hidden_states = self.transformer( + hidden_states=latent_masked_video_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_hidden_states=True, + return_dict=False, + ) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + self_guidance_hidden_states=masked_video_hidden_states, + self_guidance_masks=mask, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + + # inpainting + latents = latents.to(device) + init_latents_proper = masked_video_latents.to(device) + if do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + init_mask = init_mask.to(device) + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + masked_video_latents.to(device), noise.to(device), torch.tensor([noise_timestep]).to(device) + ).to(device) + if mask_background: + latents = init_mask * init_latents_proper + (1 - init_mask) * latents + else: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_video_latents = callback_outputs.pop("masked_video_latents", masked_video_latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_sft.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_sft.py new file mode 100755 index 0000000..e7c0a34 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_inpainting_sft.py @@ -0,0 +1,934 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL.Image +from PIL import Image, ImageFilter, ImageOps +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXTransformer3DInpaintModel +from ...models.branch_cogvideox import CogvideoXBranchModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class CogVideoXSFTInpaintPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->branch->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + branch: CogVideoXTransformer3DInpaintModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, scheduler=scheduler, branch=branch + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.masked_video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_binarize=True, do_convert_grayscale=True) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + # logger.warning( + # "The following part of your input was truncated because `max_sequence_length` is set to " + # f" {max_sequence_length} tokens: {removed_text}" + # ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if (video is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of video + Noise." + "However, either the video or the noise timestep has not been provided." + ) + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=dtype) + + if video.shape[-3] == 16: + video_latents = video + else: + video_latents = self._encode_vae_video(video=video, generator=generator) + video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1) + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def _encode_vae_video(self, video: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + video_latents = [ + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(video.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(self.vae.encode(video), generator=generator) + + video_latents = self.vae.config.scaling_factor * video_latents + + return video_latents.permute(0, 2, 1, 3, 4) + + def prepare_mask_latents( + self, mask, masked_video, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=((mask.shape[-3] - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial) + ) + # raise ValueError("Stop here") + mask = mask.to(device=device, dtype=dtype) + + masked_video = masked_video.to(device=device, dtype=dtype) + + if masked_video.shape[1] == 4: + masked_video_latents = masked_video + else: + masked_video_latents = self._encode_vae_video(masked_video, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1, 1) + if masked_video_latents.shape[0] < batch_size: + if not batch_size % masked_video_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_video_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_video_latents = masked_video_latents.repeat(batch_size // masked_video_latents.shape[0], 1, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_video_latents = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_video_latents = masked_video_latents.to(device=device, dtype=dtype) + return mask, masked_video_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.branch.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.branch.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.branch.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.branch.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.branch.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.branch.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.branch.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + # inpainting args + video: List[PIL.Image.Image] = None, + masked_video: List[PIL.Image.Image] = None, + masked_video_latents: torch.Tensor = None, + strength: float = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, + conditioning_scale: Union[float, List[float]] = 1.0, + mask_background: Optional[bool] = False, + add_first: Optional[bool] = False, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.branch.config.sample_size * self.vae_scale_factor_spatial + width = width or self.branch.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + + # 5. Preprocess mask and video + crops_coords = None + resize_mode = "default" + + original_video = video + init_video = self.video_processor.preprocess_video(original_video, height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + + # 5. Prepare latents. + latent_channels = self.branch.config.in_channels + return_video_latents = latent_channels == 16 + latents_outputs = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + video=init_video, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=return_video_latents, + ) # [B, F, C, H, W] + + if return_video_latents: + latents, noise, video_latents = latents_outputs + else: + latents, noise = latents_outputs + + + # 7. Prepare mask latent variables + mask_condition = self.masked_video_processor.preprocess_video(masked_video, height=height, width=width) + if masked_video_latents is None: + if mask_background: + masked_video = init_video * (mask_condition >= 0.5) + else: + masked_video = init_video * (mask_condition < 0.5) + else: + masked_video = masked_video_latents + + mask, masked_video_latents = self.prepare_mask_latents( + mask_condition, + masked_video, + batch_size * num_videos_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + mask = mask.permute(0, 2, 1, 3, 4).repeat(1, 1, latent_channels, 1, 1) + + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.branch.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # branch block + # mask = 1.0 - mask + latent_branch_input = torch.cat([latent_model_input, masked_video_latents, mask[:, :, :1, :, :]], dim=-3) + + noise_pred = self.branch( + hidden_states=latent_branch_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + add_first=add_first, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + + # inpainting + latents = latents.to(device) + init_latents_proper = video_latents.to(device) + if do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + init_mask = init_mask.to(device) + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper.to(device), noise.to(device), torch.tensor([noise_timestep]).to(device) + ).to(device) + before_latents = latents + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_video_latents = callback_outputs.pop("masked_video_latents", masked_video_latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py new file mode 100755 index 0000000..6491998 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -0,0 +1,821 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXDPMScheduler, CogVideoXVideoToVideoPipeline + >>> from diffusers.utils import export_to_video, load_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + >>> input_video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" + ... ) + >>> prompt = ( + ... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and " + ... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in " + ... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, " + ... "moons, but the remainder of the scene is mostly realistic." + ... ) + + >>> video = pipe( + ... video=input_video, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50 + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for video-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + video: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 60, + width: int = 90, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + timestep: Optional[torch.Tensor] = None, + ): + num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) + + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + init_latents = self.vae.config.scaling_factor * init_latents + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + strength, + negative_prompt, + callback_on_step_end_tensor_inputs, + video=None, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: List[Image.Image] = None, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + strength: float = 0.8, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + video (`List[PIL.Image.Image]`): + The input video to condition the generation on. Must be a list of images/frames of the video. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + strength (`float`, *optional*, defaults to 0.8): + Higher strength leads to more differences between original video and generated video. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + strength, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=prompt_embeds.dtype) + + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + latent_timestep, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/cogvideo/pipeline_output.py b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_output.py new file mode 100755 index 0000000..3de030d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/cogvideo/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class CogVideoXPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/diffusers/src/diffusers/pipelines/consistency_models/__init__.py b/diffusers/src/diffusers/pipelines/consistency_models/__init__.py new file mode 100755 index 0000000..162d91c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/consistency_models/__init__.py @@ -0,0 +1,24 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, +) + + +_import_structure = { + "pipeline_consistency_models": ["ConsistencyModelPipeline"], +} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_consistency_models import ConsistencyModelPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/diffusers/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py new file mode 100755 index 0000000..d2f67a6 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -0,0 +1,275 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Union + +import torch + +from ...models import UNet2DModel +from ...schedulers import CMStochasticIterativeScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import ConsistencyModelPipeline + + >>> device = "cuda" + >>> # Load the cd_imagenet64_l2 checkpoint. + >>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2" + >>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe.to(device) + + >>> # Onestep Sampling + >>> image = pipe(num_inference_steps=1).images[0] + >>> image.save("cd_imagenet64_l2_onestep_sample.png") + + >>> # Onestep sampling, class-conditional image generation + >>> # ImageNet-64 class label 145 corresponds to king penguins + >>> image = pipe(num_inference_steps=1, class_labels=145).images[0] + >>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png") + + >>> # Multistep sampling, class-conditional image generation + >>> # Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo: + >>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77 + >>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0] + >>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png") + ``` +""" + + +class ConsistencyModelPipeline(DiffusionPipeline): + r""" + Pipeline for unconditional or class-conditional image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only + compatible with [`CMStochasticIterativeScheduler`]. + """ + + model_cpu_offload_seq = "unet" + + def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None: + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + ) + + self.safety_checker = None + + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Follows diffusers.VaeImageProcessor.postprocess + def postprocess_image(self, sample: torch.Tensor, output_type: str = "pil"): + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']" + ) + + # Equivalent to diffusers.VaeImageProcessor.denormalize + sample = (sample / 2 + 0.5).clamp(0, 1) + if output_type == "pt": + return sample + + # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "np": + return sample + + # Output_type must be 'pil' + sample = self.numpy_to_pil(sample) + return sample + + def prepare_class_labels(self, batch_size, device, class_labels=None): + if self.unet.config.num_class_embeds is not None: + if isinstance(class_labels, list): + class_labels = torch.tensor(class_labels, dtype=torch.int) + elif isinstance(class_labels, int): + assert batch_size == 1, "Batch size must be 1 if classes is an int" + class_labels = torch.tensor([class_labels], dtype=torch.int) + elif class_labels is None: + # Randomly generate batch_size class labels + # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils + class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,)) + class_labels = class_labels.to(device) + else: + class_labels = None + return class_labels + + def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps): + if num_inference_steps is None and timesteps is None: + raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") + + if num_inference_steps is not None and timesteps is not None: + logger.warning( + f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;" + " `timesteps` will be used over `num_inference_steps`." + ) + + if latents is not None: + expected_shape = (batch_size, 3, img_size, img_size) + if latents.shape != expected_shape: + raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + batch_size: int = 1, + class_labels: Optional[Union[torch.Tensor, List[int], int]] = None, + num_inference_steps: int = 1, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*): + Optional class labels for conditioning class-conditional consistency models. Not used if the model is + not class-conditional. + num_inference_steps (`int`, *optional*, defaults to 1): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + # 0. Prepare call parameters + img_size = self.unet.config.sample_size + device = self._execution_device + + # 1. Check inputs + self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps) + + # 2. Prepare image latents + # Sample image latents x_0 ~ N(0, sigma_0^2 * I) + sample = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=img_size, + width=img_size, + dtype=self.unet.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 3. Handle class_labels for class-conditional models + class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 5. Denoising loop + # Multistep sampling: implements Algorithm 1 in the paper + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + scaled_sample = self.scheduler.scale_model_input(sample, t) + model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0] + + sample = self.scheduler.step(model_output, t, sample, generator=generator)[0] + + # call the callback, if provided + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + # 6. Post-process image sample + image = self.postprocess_image(sample, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/controlnet/__init__.py b/diffusers/src/diffusers/pipelines/controlnet/__init__.py new file mode 100755 index 0000000..b167105 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet/__init__.py @@ -0,0 +1,80 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["multicontrolnet"] = ["MultiControlNetModel"] + _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"] + _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"] + _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"] + _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"] + _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"] + _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"] + _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"] +try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .multicontrolnet import MultiControlNetModel + from .pipeline_controlnet import StableDiffusionControlNetPipeline + from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline + from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline + from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline + from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline + from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline + from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/controlnet/multicontrolnet.py b/diffusers/src/diffusers/pipelines/controlnet/multicontrolnet.py new file mode 100755 index 0000000..e3c5ec6 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet/multicontrolnet.py @@ -0,0 +1,183 @@ +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...models.controlnet import ControlNetModel, ControlNetOutput +from ...models.modeling_utils import ModelMixin +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MultiControlNetModel(ModelMixin): + r""" + Multiple `ControlNetModel` wrapper class for Multi-ControlNet + + This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be + compatible with `ControlNetModel`. + + Args: + controlnets (`List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `ControlNetModel` as a list. + """ + + def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + down_samples, mid_sample = controlnet( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=image, + conditioning_scale=scale, + class_labels=class_labels, + timestep_cond=timestep_cond, + attention_mask=attention_mask, + added_cond_kwargs=added_cond_kwargs, + cross_attention_kwargs=cross_attention_kwargs, + guess_mode=guess_mode, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + return down_block_res_samples, mid_block_res_sample + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = None, + safe_serialization: bool = True, + variant: Optional[str] = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + """ + for idx, controlnet in enumerate(self.nets): + suffix = "" if idx == 0 else f"_{idx}" + controlnet.save_pretrained( + save_directory + suffix, + is_main_process=is_main_process, + save_function=save_function, + safe_serialization=safe_serialization, + variant=variant, + ) + + @classmethod + def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_path (`os.PathLike`): + A path to a *directory* containing model weights saved using + [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g., + `./my_model_directory/controlnet`. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from + `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. + """ + idx = 0 + controlnets = [] + + # load controlnet and append to list until no controlnet directory exists anymore + # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained` + # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ... + model_path_to_load = pretrained_model_path + while os.path.isdir(model_path_to_load): + controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs) + controlnets.append(controlnet) + + idx += 1 + model_path_to_load = pretrained_model_path + f"_{idx}" + + logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.") + + if len(controlnets) == 0: + raise ValueError( + f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}." + ) + + return cls(controlnets) diff --git a/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet.py new file mode 100755 index 0000000..9b2fefe --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -0,0 +1,1348 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionControlNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + transposed_image = [list(t) for t in zip(*image)] + if len(transposed_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in transposed_image: + self.check_image(image_, prompt, prompt_embeds) + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + else: + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." + ) + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single + ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple + ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + # Nested lists as ControlNet condition + if isinstance(image[0], list): + # Transpose the nested image list + image = [list(t) for t in zip(*image)] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py new file mode 100755 index 0000000..86e0dde --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py @@ -0,0 +1,413 @@ +# Copyright 2024 Salesforce.com, inc. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union + +import PIL.Image +import torch +from transformers import CLIPTokenizer + +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import PNDMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..blip_diffusion.blip_image_processing import BlipImageProcessor +from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel +from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers.pipelines import BlipDiffusionControlNetPipeline + >>> from diffusers.utils import load_image + >>> from controlnet_aux import CannyDetector + >>> import torch + + >>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained( + ... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> style_subject = "flower" + >>> tgt_subject = "teapot" + >>> text_prompt = "on a marble table" + + >>> cldm_cond_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg" + ... ).resize((512, 512)) + >>> canny = CannyDetector() + >>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil") + >>> style_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg" + ... ) + >>> guidance_scale = 7.5 + >>> num_inference_steps = 50 + >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate" + + + >>> output = blip_diffusion_pipe( + ... text_prompt, + ... style_image, + ... cldm_cond_image, + ... style_subject, + ... tgt_subject, + ... guidance_scale=guidance_scale, + ... num_inference_steps=num_inference_steps, + ... neg_prompt=negative_prompt, + ... height=512, + ... width=512, + ... ).images + >>> output[0].save("image.png") + ``` +""" + + +class BlipDiffusionControlNetPipeline(DiffusionPipeline): + """ + Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer ([`CLIPTokenizer`]): + Tokenizer for the text encoder + text_encoder ([`ContextCLIPTextModel`]): + Text encoder to encode the text prompt + vae ([`AutoencoderKL`]): + VAE model to map the latents to the image + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + scheduler ([`PNDMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + qformer ([`Blip2QFormerModel`]): + QFormer model to get multi-modal embeddings from the text and image. + controlnet ([`ControlNetModel`]): + ControlNet model to get the conditioning image embedding. + image_processor ([`BlipImageProcessor`]): + Image Processor to preprocess and postprocess the image. + ctx_begin_pos (int, `optional`, defaults to 2): + Position of the context token in the text encoder. + """ + + model_cpu_offload_seq = "qformer->text_encoder->unet->vae" + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: ContextCLIPTextModel, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + scheduler: PNDMScheduler, + qformer: Blip2QFormerModel, + controlnet: ControlNetModel, + image_processor: BlipImageProcessor, + ctx_begin_pos: int = 2, + mean: List[float] = None, + std: List[float] = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + controlnet=controlnet, + image_processor=image_processor, + ) + self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std) + + def get_query_embeddings(self, input_image, src_subject): + return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False) + + # from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it + def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20): + rv = [] + for prompt, tgt_subject in zip(prompts, tgt_subjects): + prompt = f"a {tgt_subject} {prompt.strip()}" + # a trick to amplify the prompt + rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps))) + + return rv + + # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_prompt(self, query_embeds, prompt, device=None): + device = device or self._execution_device + + # embeddings for prompt, with query_embeds as context + max_len = self.text_encoder.text_model.config.max_position_embeddings + max_len -= self.qformer.config.num_query_tokens + + tokenized_prompt = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=max_len, + return_tensors="pt", + ).to(device) + + batch_size = query_embeds.shape[0] + ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size + + text_embeddings = self.text_encoder( + input_ids=tokenized_prompt.input_ids, + ctx_embeddings=query_embeds, + ctx_begin_pos=ctx_begin_pos, + )[0] + + return text_embeddings + + # Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.image_processor.preprocess( + image, + size={"width": width, "height": height}, + do_rescale=True, + do_center_crop=False, + do_normalize=False, + return_tensors="pt", + )["pixel_values"].to(device) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: List[str], + reference_image: PIL.Image.Image, + condtioning_image: PIL.Image.Image, + source_subject_category: List[str], + target_subject_category: List[str], + latents: Optional[torch.Tensor] = None, + guidance_scale: float = 7.5, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + neg_prompt: Optional[str] = "", + prompt_strength: float = 1.0, + prompt_reps: int = 20, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`List[str]`): + The prompt or prompts to guide the image generation. + reference_image (`PIL.Image.Image`): + The reference image to condition the generation on. + condtioning_image (`PIL.Image.Image`): + The conditioning canny edge image to condition the generation on. + source_subject_category (`List[str]`): + The source subject category. + target_subject_category (`List[str]`): + The target subject category. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by random sampling. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + height (`int`, *optional*, defaults to 512): + The height of the generated image. + width (`int`, *optional*, defaults to 512): + The width of the generated image. + seed (`int`, *optional*, defaults to 42): + The seed to use for random generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + neg_prompt (`str`, *optional*, defaults to ""): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_strength (`float`, *optional*, defaults to 1.0): + The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps + to amplify the prompt. + prompt_reps (`int`, *optional*, defaults to 20): + The number of times the prompt is repeated along with prompt_strength to amplify the prompt. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + reference_image = self.image_processor.preprocess( + reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt" + )["pixel_values"] + reference_image = reference_image.to(device) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(source_subject_category, str): + source_subject_category = [source_subject_category] + if isinstance(target_subject_category, str): + target_subject_category = [target_subject_category] + + batch_size = len(prompt) + + prompt = self._build_prompt( + prompts=prompt, + tgt_subjects=target_subject_category, + prompt_strength=prompt_strength, + prompt_reps=prompt_reps, + ) + query_embeds = self.get_query_embeddings(reference_image, source_subject_category) + text_embeddings = self.encode_prompt(query_embeds, prompt, device) + # 3. unconditional embedding + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + max_length = self.text_encoder.text_model.config.max_position_embeddings + + uncond_input = self.tokenizer( + [neg_prompt] * batch_size, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder( + input_ids=uncond_input.input_ids.to(device), + ctx_embeddings=None, + )[0] + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1) + latents = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=height // scale_down_factor, + width=width // scale_down_factor, + generator=generator, + latents=latents, + dtype=self.unet.dtype, + device=device, + ) + # set timesteps + extra_set_kwargs = {} + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + cond_image = self.prepare_control_image( + image=condtioning_image, + width=width, + height=height, + batch_size=batch_size, + num_images_per_prompt=1, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + down_block_res_samples, mid_block_res_sample = self.controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + controlnet_cond=cond_image, + return_dict=False, + ) + + noise_pred = self.unet( + latent_model_input, + timestep=t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + )["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + )["prev_sample"] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py new file mode 100755 index 0000000..2a4f46d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -0,0 +1,1319 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> np_image = np.array(image) + + >>> # get canny image + >>> np_image = cv2.Canny(np_image, 100, 200) + >>> np_image = np_image[:, :, None] + >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2) + >>> canny_image = Image.fromarray(np_image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", + ... num_inference_steps=20, + ... generator=generator, + ... image=image, + ... control_image=canny_image, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def prepare_image(image): + if isinstance(image, torch.Tensor): + # Batch single image + if image.ndim == 3: + image = image.unsqueeze(0) + + image = image.to(dtype=torch.float32) + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + return image + + +class StableDiffusionControlNetImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.8, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The initial image to be used as the starting point for the image generation process. Can also accept + image latents as `image`, and if passing latents directly they are not encoded again. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + # 5. Prepare controlnet_conditioning_image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py new file mode 100755 index 0000000..9f7d464 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -0,0 +1,1504 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + + + >>> def make_canny_condition(image): + ... image = np.array(image) + ... image = cv2.Canny(image, 100, 200) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... image = Image.fromarray(image) + ... return image + + + >>> control_image = make_canny_condition(init_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionControlNetInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for image inpainting using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + + + This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting + ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as + default text-to-image Stable Diffusion checkpoints + ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image + Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as + [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). + + + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + mask_image, + height, + width, + callback_steps, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if height is not None and height % 8 != 0 or width is not None and width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords, + resize_mode, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.5, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, + `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, NumPy array or tensor representing an image batch to be used as the starting point. For both + NumPy array and PyTorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a + list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a NumPy array or + a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, + `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, NumPy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a NumPy array or PyTorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for PyTorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for NumPy array, it would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, + W, 1)`, or `(H, W)`. + control_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, + `List[List[torch.Tensor]]`, or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + mask_image, + height, + width, + callback_steps, + output_type, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if padding_mask_crop is not None: + height, width = self.image_processor.get_default_height_width(image, height, width) + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + masked_image = init_image * (mask < 0.5) + _, _, height, width = init_image.shape + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py new file mode 100755 index 0000000..17fd2cb --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -0,0 +1,1845 @@ +# Copyright 2024 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .multicontrolnet import MultiControlNetModel + + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> from PIL import Image + >>> import numpy as np + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((1024, 1024)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((1024, 1024)) + + + >>> def make_canny_condition(image): + ... image = np.array(image) + ... image = cv2.Canny(image, 100, 200) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... image = Image.fromarray(image) + ... return image + + + >>> control_image = make_canny_condition(init_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLControlNetInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + TextualInversionLoaderMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + "mask", + "masked_image_latents", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: Optional[CLIPImageProcessor] = None, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + num_inference_steps, + callback_steps, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords, + resize_mode, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + masked_image_latents = None + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + control_image: Union[ + PipelineImageInput, + List[PipelineImageInput], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 0.9999): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. Note that in the case of `denoising_start` being declared as an + integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # # 0.0 Default height and width to unet + # height = height or self.unet.config.sample_size * self.vae_scale_factor + # width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 0.1 align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + control_image, + mask_image, + strength, + num_inference_steps, + callback_steps, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.1 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. set timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + self._num_timesteps = len(timesteps) + + # 5. Preprocess mask and image - resizes image and mask w.r.t height and width + # 5.1 Prepare init image + if padding_mask_crop is not None: + height, width = self.image_processor.get_default_height_width(image, height, width) + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 5.2 Prepare control images + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + raise ValueError(f"{controlnet.__class__} is not supported.") + + # 5.3 Prepare mask + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + masked_image = init_image * (mask < 0.5) + _, _, height, width = init_image.shape + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = True if denoising_start is None else False + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 8.1 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps if isinstance(controlnet, MultiControlNetModel) else keeps[0]) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + if ( + denoising_end is not None + and denoising_start is not None + and denoising_value_valid(denoising_end) + and denoising_value_valid(denoising_start) + and denoising_start >= denoising_end + ): + raise ValueError( + f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {denoising_end} when using type float." + ) + elif denoising_end is not None and denoising_value_valid(denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # # Resize control_image to match the size of the input to the controlnet + # if control_image.shape[-2:] != control_model_input.shape[-2:]: + # control_image = F.interpolate(control_image, size=control_model_input.shape[-2:], mode="bilinear", align_corners=False) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + return StableDiffusionXLPipelineOutput(images=latents) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py new file mode 100755 index 0000000..fdebcdf --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -0,0 +1,1587 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLControlNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py new file mode 100755 index 0000000..af19f3c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -0,0 +1,1656 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # pip install accelerate transformers safetensors diffusers + + >>> import torch + >>> import numpy as np + >>> from PIL import Image + + >>> from transformers import DPTImageProcessor, DPTForDepthEstimation + >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL + >>> from diffusers.utils import load_image + + + >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") + >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-depth-sdxl-1.0-small", + ... variant="fp16", + ... use_safetensors=True, + ... torch_dtype=torch.float16, + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... controlnet=controlnet, + ... vae=vae, + ... variant="fp16", + ... use_safetensors=True, + ... torch_dtype=torch.float16, + ... ) + >>> pipe.enable_model_cpu_offload() + + + >>> def get_depth_map(image): + ... image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") + ... with torch.no_grad(), torch.autocast("cuda"): + ... depth_map = depth_estimator(image).predicted_depth + + ... depth_map = torch.nn.functional.interpolate( + ... depth_map.unsqueeze(1), + ... size=(1024, 1024), + ... mode="bicubic", + ... align_corners=False, + ... ) + ... depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + ... depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + ... depth_map = (depth_map - depth_min) / (depth_max - depth_min) + ... image = torch.cat([depth_map] * 3, dim=1) + ... image = image.permute(0, 2, 3, 1).cpu().numpy()[0] + ... image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) + ... return image + + + >>> prompt = "A robot, 4k photo" + >>> image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ).resize((1024, 1024)) + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> depth_image = get_depth_map(image) + + >>> images = pipe( + ... prompt, + ... image=image, + ... control_image=depth_image, + ... strength=0.99, + ... num_inference_steps=50, + ... controlnet_conditioning_scale=controlnet_conditioning_scale, + ... ).images + >>> images[0].save(f"robot_cat.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLControlNetImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, +): + r""" + Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.8, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The initial image will be used as the starting point for the image generation process. Can also accept + image latents as `image`, if passing latents directly, it will not be encoded again. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also + be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in + init, images must be passed as a list such that each element of the list can be correctly batched for + input to a single controlnet. + height (`int`, *optional*, defaults to the size of control_image): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to the size of control_image): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` + containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + control_image, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image and controlnet_conditioning_image + image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = control_image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + True, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(control_image, list): + original_size = original_size or control_image[0].shape[-2:] + else: + original_size = original_size or control_image.shape[-2:] + target_size = target_size or (height, width) + + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + add_text_embeds = pooled_prompt_embeds + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/diffusers/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py new file mode 100755 index 0000000..8a2cc08 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py @@ -0,0 +1,532 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from ..stable_diffusion import FlaxStableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> from diffusers.utils import load_image, make_image_grid + >>> from PIL import Image + >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel + + + >>> def create_key(seed=0): + ... return jax.random.PRNGKey(seed) + + + >>> rng = create_key(0) + + >>> # get canny image + >>> canny_image = load_image( + ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg" + ... ) + + >>> prompts = "best quality, extremely detailed" + >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" + + >>> # load control net and stable diffusion v1-5 + >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32 + ... ) + >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32 + ... ) + >>> params["controlnet"] = controlnet_params + + >>> num_samples = jax.device_count() + >>> rng = jax.random.split(rng, jax.device_count()) + + >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) + >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples) + >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) + + >>> p_params = replicate(params) + >>> prompt_ids = shard(prompt_ids) + >>> negative_prompt_ids = shard(negative_prompt_ids) + >>> processed_image = shard(processed_image) + + >>> output = pipe( + ... prompt_ids=prompt_ids, + ... image=processed_image, + ... params=p_params, + ... prng_seed=rng, + ... num_inference_steps=50, + ... neg_prompt_ids=negative_prompt_ids, + ... jit=True, + ... ).images + + >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) + >>> output_images = make_image_grid(output_images, num_samples // 4, 4) + >>> output_images.save("generated_image.png") + ``` +""" + + +class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): + r""" + Flax-based pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.FlaxCLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`FlaxUNet2DConditionModel`]): + A `FlaxUNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`FlaxControlNetModel`]: + Provides additional conditioning to the `unet` during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + controlnet: FlaxControlNetModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_text_inputs(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + + return text_input.input_ids + + def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]): + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) + + return processed_images + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, + params: Union[Dict, FrozenDict], + prng_seed: jax.Array, + num_inference_steps: int, + guidance_scale: float, + latents: Optional[jnp.ndarray] = None, + neg_prompt_ids: Optional[jnp.ndarray] = None, + controlnet_conditioning_scale: float = 1.0, + ): + height, width = image.shape[-2:] + if height % 64 != 0 or width % 64 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + image = jnp.concatenate([image] * 2) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + down_block_res_samples, mid_block_res_sample = self.controlnet.apply( + {"params": params["controlnet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, + params: Union[Dict, FrozenDict], + prng_seed: jax.Array, + num_inference_steps: int = 50, + guidance_scale: Union[float, jnp.ndarray] = 7.5, + latents: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, + controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0, + return_dict: bool = True, + jit: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt_ids (`jnp.ndarray`): + The prompt or prompts to guide the image generation. + image (`jnp.ndarray`): + Array representing the ControlNet input condition to provide guidance to the `unet` for generation. + params (`Dict` or `FrozenDict`): + Dictionary containing the model parameters/weights. + prng_seed (`jax.Array`): + Array containing random number generator key. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + latents (`jnp.ndarray`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + array is generated by sampling using the supplied random `generator`. + controlnet_conditioning_scale (`float` or `jnp.ndarray`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. + + + + This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a + future release. + + + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated images + and the second element is a list of `bool`s indicating whether the corresponding generated image + contains "not-safe-for-work" (nsfw) content. + """ + + height, width = image.shape[-2:] + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if isinstance(controlnet_conditioning_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + else: + images = self._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.array(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0), + static_broadcasted_argnums=(0, 5), +) +def _p_generate( + pipe, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, +): + return pipe._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess(image, dtype): + image = image.convert("RGB") + w, h = image.size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return image diff --git a/diffusers/src/diffusers/pipelines/controlnet_hunyuandit/__init__.py b/diffusers/src/diffusers/pipelines/controlnet_hunyuandit/__init__.py new file mode 100755 index 0000000..34c5979 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet_hunyuandit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuandit_controlnet"] = ["HunyuanDiTControlNetPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuandit_controlnet import HunyuanDiTControlNetPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/diffusers/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py new file mode 100755 index 0000000..10c521e --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py @@ -0,0 +1,1042 @@ +# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import ( + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import HunyuanDiT2DControlNetModel, HunyuanDiTControlNetPipeline + import torch + + controlnet = HunyuanDiT2DControlNetModel.from_pretrained( + "Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16 + ) + + pipe = HunyuanDiTControlNetPipeline.from_pretrained( + "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16 + ) + pipe.to("cuda") + + from diffusers.utils import load_image + + cond_image = load_image( + "https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true" + ) + + ## You may also use English prompt as HunyuanDiT supports both English and Chinese + prompt = "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" + # prompt="At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere." + image = pipe( + prompt, + height=1024, + width=1024, + control_image=cond_image, + num_inference_steps=50, + ).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTControlNetPipeline(DiffusionPipeline): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`MT5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + controlnet ([`HunyuanDiT2DControlNetModel`] or `List[HunyuanDiT2DControlNetModel]` or [`HunyuanDiT2DControlNetModel`]): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + controlnet: Union[ + HunyuanDiT2DControlNetModel, + List[HunyuanDiT2DControlNetModel], + Tuple[HunyuanDiT2DControlNetModel], + HunyuanDiT2DMultiControlNetModel, + ], + text_encoder_2=T5EncoderModel, + tokenizer_2=MT5Tokenizer, + requires_safety_checker: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + controlnet=controlnet, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + device: torch.device = None, + dtype: torch.dtype = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if device is None: + device = self._execution_device + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`Tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare control image + if isinstance(self.controlnet, HunyuanDiT2DControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = control_image.shape[-2:] + + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = control_image * self.vae.config.scaling_factor + + elif isinstance(self.controlnet, HunyuanDiT2DMultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = control_image_ * self.vae.config.scaling_factor + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # controlnet(s) inference + control_block_samples = self.controlnet( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + controlnet_cond=control_image, + conditioning_scale=controlnet_conditioning_scale, + )[0] + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + controlnet_block_samples=control_block_samples, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/controlnet_sd3/__init__.py b/diffusers/src/diffusers/pipelines/controlnet_sd3/__init__.py new file mode 100755 index 0000000..aeb61dc --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet_sd3/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_3_controlnet"] = ["StableDiffusion3ControlNetPipeline"] + _import_structure["pipeline_stable_diffusion_3_controlnet_inpainting"] = [ + "StableDiffusion3ControlNetInpaintingPipeline" + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline + from .pipeline_stable_diffusion_3_controlnet_inpainting import StableDiffusion3ControlNetInpaintingPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/diffusers/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py new file mode 100755 index 0000000..f524602 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -0,0 +1,1097 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusion3ControlNetPipeline + >>> from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel + >>> from diffusers.utils import load_image + + >>> controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16) + + >>> pipe = StableDiffusion3ControlNetPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") + >>> prompt = "A girl holding a sign that says InstantX" + >>> image = pipe(prompt, control_image=control_image, controlnet_conditioning_scale=0.7).images[0] + >>> image.save("sd3.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + controlnet ([`SD3ControlNetModel`] or `List[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + controlnet: Union[ + SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel + ], + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + self.tokenizer_max_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + max_sequence_length: int = 256, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + controlnet_pooled_projections: Optional[torch.FloatTensor] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + controlnet_pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): + Embeddings projected from the embeddings of controlnet input conditions. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, SD3MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 3. Prepare control image + if isinstance(self.controlnet, SD3ControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = control_image.shape[-2:] + + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = control_image * self.vae.config.scaling_factor + + elif isinstance(self.controlnet, SD3MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = control_image_ * self.vae.config.scaling_factor + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + if controlnet_pooled_projections is None: + controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds) + else: + controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # controlnet(s) inference + control_block_samples = self.controlnet( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=controlnet_pooled_projections, + joint_attention_kwargs=self.joint_attention_kwargs, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + return_dict=False, + )[0] + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + block_controlnet_hidden_states=control_block_samples, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/diffusers/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py new file mode 100755 index 0000000..47fc6d6 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -0,0 +1,1144 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The AlimamaCreative Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.utils import load_image, check_min_version + >>> from diffusers.pipelines import StableDiffusion3ControlNetInpaintingPipeline + >>> from diffusers.models.controlnet_sd3 import SD3ControlNetModel + + >>> controlnet = SD3ControlNetModel.from_pretrained( + ... "alimama-creative/SD3-Controlnet-Inpainting", use_safetensors=True, extra_conditioning_channels=1 + ... ) + >>> pipe = StableDiffusion3ControlNetInpaintingPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", + ... controlnet=controlnet, + ... torch_dtype=torch.float16, + ... ) + >>> pipe.text_encoder.to(torch.float16) + >>> pipe.controlnet.to(torch.float16) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting/resolve/main/images/dog.png" + ... ) + >>> mask = load_image( + ... "https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting/resolve/main/images/dog_mask.png" + ... ) + >>> width = 1024 + >>> height = 1024 + >>> prompt = "A cat is sitting next to a puppy." + >>> generator = torch.Generator(device="cuda").manual_seed(24) + >>> res_image = pipe( + ... negative_prompt="deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW", + ... prompt=prompt, + ... height=height, + ... width=width, + ... control_image=image, + ... control_mask=mask, + ... num_inference_steps=28, + ... generator=generator, + ... controlnet_conditioning_scale=0.95, + ... guidance_scale=7, + ... ).images[0] + >>> res_image.save(f"sd3.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + controlnet ([`SD3ControlNetModel`] or `List[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + controlnet: Union[ + SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel + ], + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_resize=True, do_convert_rgb=True, do_normalize=True + ) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + do_resize=True, + do_convert_grayscale=True, + do_normalize=False, + do_binarize=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + self.tokenizer_max_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + max_sequence_length: int = 256, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + def prepare_image_with_mask( + self, + image, + mask, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + # Prepare image + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + # Prepare mask + if isinstance(mask, torch.Tensor): + pass + else: + mask = self.mask_processor.preprocess(mask, height=height, width=width) + mask = mask.repeat_interleave(repeat_by, dim=0) + mask = mask.to(device=device, dtype=dtype) + + # Get masked image + masked_image = image.clone() + masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1 + + # Encode to latents + image_latents = self.vae.encode(masked_image).latent_dist.sample() + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + image_latents = image_latents.to(dtype) + + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = 1 - mask + control_image = torch.cat([image_latents, mask], dim=1) + + if do_classifier_free_guidance and not guess_mode: + control_image = torch.cat([control_image] * 2) + + return control_image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + control_image: PipelineImageInput = None, + control_mask: PipelineImageInput = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + controlnet_pooled_projections: Optional[torch.FloatTensor] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to + be masked out with `control_mask` and repainted according to `prompt`). For both numpy array and + pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the + expected shape should be `(B, C, H, W)`. If it is a numpy array or a list of arrays, the expected shape + should be `(B, H, W, C)` or `(H, W, C)`. + control_mask (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`. And + for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`, or `(H, W)`. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + controlnet_pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): + Embeddings projected from the embeddings of controlnet input conditions. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, SD3MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 3. Prepare control image + if isinstance(self.controlnet, SD3ControlNetModel): + control_image = self.prepare_image_with_mask( + image=control_image, + mask=control_mask, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + latent_height, latent_width = control_image.shape[-2:] + + height = latent_height * self.vae_scale_factor + width = latent_width * self.vae_scale_factor + + elif isinstance(self.controlnet, SD3MultiControlNetModel): + raise NotImplementedError("MultiControlNetModel is not supported for SD3ControlNetInpaintingPipeline.") + else: + assert False + + if controlnet_pooled_projections is None: + controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds) + else: + controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # controlnet(s) inference + control_block_samples = self.controlnet( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=controlnet_pooled_projections, + joint_attention_kwargs=self.joint_attention_kwargs, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + return_dict=False, + )[0] + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + block_controlnet_hidden_states=control_block_samples, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/controlnet_xs/__init__.py b/diffusers/src/diffusers/pipelines/controlnet_xs/__init__.py new file mode 100755 index 0000000..978278b --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet_xs/__init__.py @@ -0,0 +1,68 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"] + _import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"] +try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline + from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/diffusers/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py new file mode 100755 index 0000000..ca10e65 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -0,0 +1,916 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 + + >>> controlnet = ControlNetXSAdapter.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionControlNetXSPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetControlNetXSModel], + controlnet: ControlNetXSAdapter, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel.from_unet(unet, controlnet) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` and `controlnet_conditioning_scale` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.unet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.unet, UNetControlNetXSModel) + or is_compiled + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + def clip_skip(self): + return self._clip_skip + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare image + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=unet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + is_controlnet_compiled = is_compiled_module(self.unet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if is_controlnet_compiled and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + apply_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end + ) + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=True, + apply_control=apply_control, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/diffusers/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py new file mode 100755 index 0000000..326cfda --- /dev/null +++ b/diffusers/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -0,0 +1,1111 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> controlnet = ControlNetXSAdapter.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionXLControlNetXSPipeline( + DiffusionPipeline, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetControlNetXSModel], + controlnet: ControlNetXSAdapter, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel.from_unet(unet, controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Check `image` and ``controlnet_conditioning_scale`` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.unet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.unet, UNetControlNetXSModel) + or is_compiled + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + def clip_skip(self): + return self._clip_skip + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. + control_guidance_start (`float`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is + returned, otherwise a `tuple` is returned containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + + # 4. Prepare image + if isinstance(unet, UNetControlNetXSModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=unet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + is_controlnet_compiled = is_compiled_module(self.unet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if is_controlnet_compiled and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # predict the noise residual + apply_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end + ) + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + apply_control=apply_control, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # manually for max memory savings + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/dance_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/dance_diffusion/__init__.py new file mode 100755 index 0000000..0d3e466 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/dance_diffusion/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_dance_diffusion": ["DanceDiffusionPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_dance_diffusion import DanceDiffusionPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/diffusers/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py new file mode 100755 index 0000000..bcd36c4 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py @@ -0,0 +1,156 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +import torch + +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DanceDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for audio generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet1DModel`]): + A `UNet1DModel` to denoise the encoded audio. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of + [`IPNDMScheduler`]. + """ + + model_cpu_offload_seq = "unet" + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 100, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + audio_length_in_s: Optional[float] = None, + return_dict: bool = True, + ) -> Union[AudioPipelineOutput, Tuple]: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of audio samples to generate. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher-quality audio sample at + the expense of slower inference. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`): + The length of the generated audio sample in seconds. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. + + Example: + + ```py + from diffusers import DiffusionPipeline + from scipy.io.wavfile import write + + model_id = "harmonai/maestro-150k" + pipe = DiffusionPipeline.from_pretrained(model_id) + pipe = pipe.to("cuda") + + audios = pipe(audio_length_in_s=4.0).audios + + # To save locally + for i, audio in enumerate(audios): + write(f"maestro_test_{i}.wav", pipe.unet.sample_rate, audio.transpose()) + + # To dislay in google colab + import IPython.display as ipd + + for audio in audios: + display(ipd.Audio(audio, rate=pipe.unet.sample_rate)) + ``` + + Returns: + [`~pipelines.AudioPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated audio. + """ + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate + + sample_size = audio_length_in_s * self.unet.config.sample_rate + + down_scale_factor = 2 ** len(self.unet.up_blocks) + if sample_size < 3 * down_scale_factor: + raise ValueError( + f"{audio_length_in_s} is too small. Make sure it's bigger or equal to" + f" {3 * down_scale_factor / self.unet.config.sample_rate}." + ) + + original_sample_size = int(sample_size) + if sample_size % down_scale_factor != 0: + sample_size = ( + (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1 + ) * down_scale_factor + logger.info( + f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled" + f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising" + " process." + ) + sample_size = int(sample_size) + + dtype = next(self.unet.parameters()).dtype + shape = (batch_size, self.unet.config.in_channels, sample_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + audio = randn_tensor(shape, generator=generator, device=self._execution_device, dtype=dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps, device=audio.device) + self.scheduler.timesteps = self.scheduler.timesteps.to(dtype) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(audio, t).sample + + # 2. compute previous audio sample: x_t -> t_t-1 + audio = self.scheduler.step(model_output, t, audio).prev_sample + + audio = audio.clamp(-1, 1).float().cpu().numpy() + + audio = audio[:, :, :original_sample_size] + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/diffusers/src/diffusers/pipelines/ddim/__init__.py b/diffusers/src/diffusers/pipelines/ddim/__init__.py new file mode 100755 index 0000000..d9eede4 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/ddim/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_ddim": ["DDIMPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_ddim import DDIMPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/src/diffusers/pipelines/ddim/pipeline_ddim.py b/diffusers/src/diffusers/pipelines/ddim/pipeline_ddim.py new file mode 100755 index 0000000..a3b967e --- /dev/null +++ b/diffusers/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -0,0 +1,154 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch + +from ...schedulers import DDIMScheduler +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DDIMPipeline(DiffusionPipeline): + r""" + Pipeline for image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + model_cpu_offload_seq = "unet" + + def __init__(self, unet, scheduler): + super().__init__() + + # make sure scheduler can always be converted to DDIM + scheduler = DDIMScheduler.from_config(scheduler.config) + + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + eta: float = 0.0, + num_inference_steps: int = 50, + use_clipped_model_output: Optional[bool] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0` corresponds to + DDIM and `1` corresponds to DDPM. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + use_clipped_model_output (`bool`, *optional*, defaults to `None`): + If `True` or `False`, see documentation for [`DDIMScheduler.step`]. If `None`, nothing is passed + downstream to the scheduler (use `None` for schedulers which don't support this argument). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import DDIMPipeline + >>> import PIL.Image + >>> import numpy as np + + >>> # load model and scheduler + >>> pipe = DDIMPipeline.from_pretrained("fusing/ddim-lsun-bedroom") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pipe(eta=0.0, num_inference_steps=50) + + >>> # process image to PIL + >>> image_processed = image.cpu().permute(0, 2, 3, 1) + >>> image_processed = (image_processed + 1.0) * 127.5 + >>> image_processed = image_processed.numpy().astype(np.uint8) + >>> image_pil = PIL.Image.fromarray(image_processed[0]) + + >>> # save image + >>> image_pil.save("test.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + + # Sample gaussian noise to begin loop + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. predict previous mean of image x_t-1 and add variance depending on eta + # eta corresponds to η in paper and should be between [0, 1] + # do x_t -> x_t-1 + image = self.scheduler.step( + model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator + ).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/ddpm/__init__.py b/diffusers/src/diffusers/pipelines/ddpm/__init__.py new file mode 100755 index 0000000..eb41dd1 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/ddpm/__init__.py @@ -0,0 +1,22 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, +) + + +_import_structure = {"pipeline_ddpm": ["DDPMPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_ddpm import DDPMPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/diffusers/src/diffusers/pipelines/ddpm/pipeline_ddpm.py new file mode 100755 index 0000000..093a3cd --- /dev/null +++ b/diffusers/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -0,0 +1,127 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +import torch + +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DDPMPipeline(DiffusionPipeline): + r""" + Pipeline for image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + model_cpu_offload_seq = "unet" + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 1000, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import DDPMPipeline + + >>> # load model and scheduler + >>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pipe().images[0] + + >>> # save image + >>> image.save("ddpm_generated_image.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # Sample gaussian noise to begin loop + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = randn_tensor(image_shape, generator=generator) + image = image.to(self.device) + else: + image = randn_tensor(image_shape, generator=generator, device=self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/deepfloyd_if/__init__.py b/diffusers/src/diffusers/pipelines/deepfloyd_if/__init__.py new file mode 100755 index 0000000..79aab1f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deepfloyd_if/__init__.py @@ -0,0 +1,85 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = { + "timesteps": [ + "fast27_timesteps", + "smart100_timesteps", + "smart185_timesteps", + "smart27_timesteps", + "smart50_timesteps", + "super100_timesteps", + "super27_timesteps", + "super40_timesteps", + ] +} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_if"] = ["IFPipeline"] + _import_structure["pipeline_if_img2img"] = ["IFImg2ImgPipeline"] + _import_structure["pipeline_if_img2img_superresolution"] = ["IFImg2ImgSuperResolutionPipeline"] + _import_structure["pipeline_if_inpainting"] = ["IFInpaintingPipeline"] + _import_structure["pipeline_if_inpainting_superresolution"] = ["IFInpaintingSuperResolutionPipeline"] + _import_structure["pipeline_if_superresolution"] = ["IFSuperResolutionPipeline"] + _import_structure["pipeline_output"] = ["IFPipelineOutput"] + _import_structure["safety_checker"] = ["IFSafetyChecker"] + _import_structure["watermark"] = ["IFWatermarker"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_if import IFPipeline + from .pipeline_if_img2img import IFImg2ImgPipeline + from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline + from .pipeline_if_inpainting import IFInpaintingPipeline + from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline + from .pipeline_if_superresolution import IFSuperResolutionPipeline + from .pipeline_output import IFPipelineOutput + from .safety_checker import IFSafetyChecker + from .timesteps import ( + fast27_timesteps, + smart27_timesteps, + smart50_timesteps, + smart100_timesteps, + smart185_timesteps, + super27_timesteps, + super40_timesteps, + super100_timesteps, + ) + from .watermark import IFWatermarker + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py new file mode 100755 index 0000000..f545b24 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py @@ -0,0 +1,774 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + + >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt" + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> safety_modules = { + ... "feature_extractor": pipe.feature_extractor, + ... "safety_checker": pipe.safety_checker, + ... "watermarker": pipe.watermarker, + ... } + >>> super_res_2_pipe = DiffusionPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16 + ... ) + >>> super_res_2_pipe.enable_model_cpu_offload() + + >>> image = super_res_2_pipe( + ... prompt=prompt, + ... image=image, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + model_cpu_offload_seq = "text_encoder->unet" + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @torch.no_grad() + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + + # 5. Prepare intermediate images + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + self.unet.config.in_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. Apply watermark + if self.watermarker is not None: + image = self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py new file mode 100755 index 0000000..0701791 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py @@ -0,0 +1,895 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image.resize((768, 512)) + + >>> pipe = IFImg2ImgPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A fantasy landscape in style minecraft" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", + ... text_encoder=None, + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFImg2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + model_cpu_offload_seq = "text_encoder->unet" + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @torch.no_grad() + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.config.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None + ): + _, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + image = self.scheduler.add_noise(image, noise, timestep) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + strength: float = 0.7, + num_inference_steps: int = 80, + timesteps: List[int] = None, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.7): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 80): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 10.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. Prepare intermediate images + image = self.preprocess_image(image) + image = image.to(device=device, dtype=dtype) + + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py new file mode 100755 index 0000000..6685ba6 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py @@ -0,0 +1,1011 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image.resize((768, 512)) + + >>> pipe = IFImg2ImgPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A fantasy landscape in style minecraft" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", + ... text_encoder=None, + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor"] + model_cpu_offload_seq = "text_encoder->unet" + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet.config.in_channels != 6: + logger.warning( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + original_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # original_image + + if isinstance(original_image, list): + check_image_type = original_image[0] + else: + check_image_type = original_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`original_image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(original_image, list): + image_batch_size = len(original_image) + elif isinstance(original_image, torch.Tensor): + image_batch_size = original_image.shape[0] + elif isinstance(original_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(original_image, np.ndarray): + image_batch_size = original_image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError( + f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image + def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.config.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor: + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] + + image = np.stack(image, axis=0) # to np + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.prepare_intermediate_images + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None + ): + _, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + image = self.scheduler.add_noise(image, noise, timestep) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + original_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + strength: float = 0.8, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 250, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + original_image (`torch.Tensor` or `PIL.Image.Image`): + The original image that `image` was varied from. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + noise_level (`int`, *optional*, defaults to 250): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + original_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. prepare original image + original_image = self.preprocess_original_image(original_image) + original_image = original_image.to(device=device, dtype=dtype) + + # 6. Prepare intermediate images + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + original_image, + noise_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator, + ) + + # 7. Prepare upscaled image and noise level + _, _, height, width = original_image.shape + + image = self.preprocess_image(image, num_images_per_prompt, device) + + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 12. Convert to PIL + image = self.numpy_to_pil(image) + + # 13. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + else: + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py new file mode 100755 index 0000000..7fca0bc --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py @@ -0,0 +1,1014 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png" + >>> response = requests.get(url) + >>> mask_image = Image.open(BytesIO(response.content)) + >>> mask_image = mask_image + + >>> pipe = IFInpaintingPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "blue sunglasses" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... image=original_image, + ... mask_image=mask_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... mask_image=mask_image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFInpaintingPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + model_cpu_offload_seq = "text_encoder->unet" + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + mask_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # mask_image + + if isinstance(mask_image, list): + check_image_type = mask_image[0] + else: + check_image_type = mask_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`mask_image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(mask_image, list): + image_batch_size = len(mask_image) + elif isinstance(mask_image, torch.Tensor): + image_batch_size = mask_image.shape[0] + elif isinstance(mask_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(mask_image, np.ndarray): + image_batch_size = mask_image.shape[0] + else: + assert False + + if image_batch_size != 1 and batch_size != image_batch_size: + raise ValueError( + f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.config.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + def preprocess_mask_image(self, mask_image) -> torch.Tensor: + if not isinstance(mask_image, list): + mask_image = [mask_image] + + if isinstance(mask_image[0], torch.Tensor): + mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0) + + if mask_image.ndim == 2: + # Batch and add channel dim for single mask + mask_image = mask_image.unsqueeze(0).unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] == 1: + # Single mask, the 0'th dimension is considered to be + # the existing batch size of 1 + mask_image = mask_image.unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] != 1: + # Batch of mask, the 0'th dimension is considered to be + # the batching dimension + mask_image = mask_image.unsqueeze(1) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + + elif isinstance(mask_image[0], PIL.Image.Image): + new_mask_image = [] + + for mask_image_ in mask_image: + mask_image_ = mask_image_.convert("L") + mask_image_ = resize(mask_image_, self.unet.config.sample_size) + mask_image_ = np.array(mask_image_) + mask_image_ = mask_image_[None, None, :] + new_mask_image.append(mask_image_) + + mask_image = new_mask_image + + mask_image = np.concatenate(mask_image, axis=0) + mask_image = mask_image.astype(np.float32) / 255.0 + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + elif isinstance(mask_image[0], np.ndarray): + mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + return mask_image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None + ): + image_batch_size, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + noised_image = self.scheduler.add_noise(image, noise, timestep) + + image = (1 - mask_image) * image + mask_image * noised_image + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + mask_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + strength (`float`, *optional*, defaults to 1.0): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + mask_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. Prepare intermediate images + image = self.preprocess_image(image) + image = image.to(device=device, dtype=dtype) + + mask_image = self.preprocess_mask_image(mask_image) + mask_image = mask_image.to(device=device, dtype=dtype) + + if mask_image.shape[0] == 1: + mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + else: + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + prev_intermediate_images = intermediate_images + + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py new file mode 100755 index 0000000..4f04a1d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py @@ -0,0 +1,1121 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png" + >>> response = requests.get(url) + >>> mask_image = Image.open(BytesIO(response.content)) + >>> mask_image = mask_image + + >>> pipe = IFInpaintingPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "blue sunglasses" + + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + >>> image = pipe( + ... image=original_image, + ... mask_image=mask_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... mask_image=mask_image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` + """ + + +class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + model_cpu_offload_seq = "text_encoder->unet" + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet.config.in_channels != 6: + logger.warning( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + original_image, + mask_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # original_image + + if isinstance(original_image, list): + check_image_type = original_image[0] + else: + check_image_type = original_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`original_image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(original_image, list): + image_batch_size = len(original_image) + elif isinstance(original_image, torch.Tensor): + image_batch_size = original_image.shape[0] + elif isinstance(original_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(original_image, np.ndarray): + image_batch_size = original_image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError( + f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}" + ) + + # mask_image + + if isinstance(mask_image, list): + check_image_type = mask_image[0] + else: + check_image_type = mask_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`mask_image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(mask_image, list): + image_batch_size = len(mask_image) + elif isinstance(mask_image, torch.Tensor): + image_batch_size = mask_image.shape[0] + elif isinstance(mask_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(mask_image, np.ndarray): + image_batch_size = mask_image.shape[0] + else: + assert False + + if image_batch_size != 1 and batch_size != image_batch_size: + raise ValueError( + f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image + def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.config.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor: + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] + + image = np.stack(image, axis=0) # to np + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.preprocess_mask_image + def preprocess_mask_image(self, mask_image) -> torch.Tensor: + if not isinstance(mask_image, list): + mask_image = [mask_image] + + if isinstance(mask_image[0], torch.Tensor): + mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0) + + if mask_image.ndim == 2: + # Batch and add channel dim for single mask + mask_image = mask_image.unsqueeze(0).unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] == 1: + # Single mask, the 0'th dimension is considered to be + # the existing batch size of 1 + mask_image = mask_image.unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] != 1: + # Batch of mask, the 0'th dimension is considered to be + # the batching dimension + mask_image = mask_image.unsqueeze(1) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + + elif isinstance(mask_image[0], PIL.Image.Image): + new_mask_image = [] + + for mask_image_ in mask_image: + mask_image_ = mask_image_.convert("L") + mask_image_ = resize(mask_image_, self.unet.config.sample_size) + mask_image_ = np.array(mask_image_) + mask_image_ = mask_image_[None, None, :] + new_mask_image.append(mask_image_) + + mask_image = new_mask_image + + mask_image = np.concatenate(mask_image, axis=0) + mask_image = mask_image.astype(np.float32) / 255.0 + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + elif isinstance(mask_image[0], np.ndarray): + mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + return mask_image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.prepare_intermediate_images + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None + ): + image_batch_size, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + noised_image = self.scheduler.add_noise(image, noise, timestep) + + image = (1 - mask_image) * image + mask_image * noised_image + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + original_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + mask_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + strength: float = 0.8, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 0, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + original_image (`torch.Tensor` or `PIL.Image.Image`): + The original image that `image` was varied from. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + noise_level (`int`, *optional*, defaults to 0): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + original_image, + mask_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. prepare original image + original_image = self.preprocess_original_image(original_image) + original_image = original_image.to(device=device, dtype=dtype) + + # 6. prepare mask image + mask_image = self.preprocess_mask_image(mask_image) + mask_image = mask_image.to(device=device, dtype=dtype) + + if mask_image.shape[0] == 1: + mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + else: + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + + # 6. Prepare intermediate images + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + original_image, + noise_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + mask_image, + generator, + ) + + # 7. Prepare upscaled image and noise level + _, _, height, width = original_image.shape + + image = self.preprocess_image(image, num_images_per_prompt, device) + + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + prev_intermediate_images = intermediate_images + + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 12. Convert to PIL + image = self.numpy_to_pil(image) + + # 13. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + else: + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py new file mode 100755 index 0000000..891963f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py @@ -0,0 +1,870 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + + >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + model_cpu_offload_seq = "text_encoder->unet" + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet.config.in_channels != 6: + logger.warning( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})" + ) + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_intermediate_images + def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def preprocess_image(self, image, num_images_per_prompt, device): + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] + + image = np.stack(image, axis=0) # to np + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: int = None, + width: int = None, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 250, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to None): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to None): + The width in pixels of the generated image. + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`): + The image to be upscaled. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*, defaults to None): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + noise_level (`int`, *optional*, defaults to 250): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + + # 5. Prepare intermediate images + num_channels = self.unet.config.in_channels // 2 + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + num_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare upscaled image and noise level + image = self.preprocess_image(image, num_images_per_prompt, device) + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 9. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 10. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 11. Convert to PIL + image = self.numpy_to_pil(image) + + # 12. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 9. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 10. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_output.py b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_output.py new file mode 100755 index 0000000..7f39ab5 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deepfloyd_if/pipeline_output.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class IFPipelineOutput(BaseOutput): + """ + Args: + Output class for Stable Diffusion pipelines. + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content or a watermark. `None` if safety checking could not be performed. + watermark_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety + checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_detected: Optional[List[bool]] + watermark_detected: Optional[List[bool]] diff --git a/diffusers/src/diffusers/pipelines/deepfloyd_if/safety_checker.py b/diffusers/src/diffusers/pipelines/deepfloyd_if/safety_checker.py new file mode 100755 index 0000000..8ffeed5 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deepfloyd_if/safety_checker.py @@ -0,0 +1,59 @@ +import numpy as np +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class IFSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModelWithProjection(config.vision_config) + + self.p_head = nn.Linear(config.vision_config.projection_dim, 1) + self.w_head = nn.Linear(config.vision_config.projection_dim, 1) + + @torch.no_grad() + def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5): + image_embeds = self.vision_model(clip_input)[0] + + nsfw_detected = self.p_head(image_embeds) + nsfw_detected = nsfw_detected.flatten() + nsfw_detected = nsfw_detected > p_threshold + nsfw_detected = nsfw_detected.tolist() + + if any(nsfw_detected): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + for idx, nsfw_detected_ in enumerate(nsfw_detected): + if nsfw_detected_: + images[idx] = np.zeros(images[idx].shape) + + watermark_detected = self.w_head(image_embeds) + watermark_detected = watermark_detected.flatten() + watermark_detected = watermark_detected > w_threshold + watermark_detected = watermark_detected.tolist() + + if any(watermark_detected): + logger.warning( + "Potential watermarked content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + for idx, watermark_detected_ in enumerate(watermark_detected): + if watermark_detected_: + images[idx] = np.zeros(images[idx].shape) + + return images, nsfw_detected, watermark_detected diff --git a/diffusers/src/diffusers/pipelines/deepfloyd_if/timesteps.py b/diffusers/src/diffusers/pipelines/deepfloyd_if/timesteps.py new file mode 100755 index 0000000..d44285c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deepfloyd_if/timesteps.py @@ -0,0 +1,579 @@ +fast27_timesteps = [ + 999, + 800, + 799, + 600, + 599, + 500, + 400, + 399, + 377, + 355, + 333, + 311, + 288, + 266, + 244, + 222, + 200, + 199, + 177, + 155, + 133, + 111, + 88, + 66, + 44, + 22, + 0, +] + +smart27_timesteps = [ + 999, + 976, + 952, + 928, + 905, + 882, + 858, + 857, + 810, + 762, + 715, + 714, + 572, + 429, + 428, + 286, + 285, + 238, + 190, + 143, + 142, + 118, + 95, + 71, + 47, + 24, + 0, +] + +smart50_timesteps = [ + 999, + 988, + 977, + 966, + 955, + 944, + 933, + 922, + 911, + 900, + 899, + 879, + 859, + 840, + 820, + 800, + 799, + 766, + 733, + 700, + 699, + 650, + 600, + 599, + 500, + 499, + 400, + 399, + 350, + 300, + 299, + 266, + 233, + 200, + 199, + 179, + 159, + 140, + 120, + 100, + 99, + 88, + 77, + 66, + 55, + 44, + 33, + 22, + 11, + 0, +] + +smart100_timesteps = [ + 999, + 995, + 992, + 989, + 985, + 981, + 978, + 975, + 971, + 967, + 964, + 961, + 957, + 956, + 951, + 947, + 942, + 937, + 933, + 928, + 923, + 919, + 914, + 913, + 908, + 903, + 897, + 892, + 887, + 881, + 876, + 871, + 870, + 864, + 858, + 852, + 846, + 840, + 834, + 828, + 827, + 820, + 813, + 806, + 799, + 792, + 785, + 784, + 777, + 770, + 763, + 756, + 749, + 742, + 741, + 733, + 724, + 716, + 707, + 699, + 698, + 688, + 677, + 666, + 656, + 655, + 645, + 634, + 623, + 613, + 612, + 598, + 584, + 570, + 569, + 555, + 541, + 527, + 526, + 505, + 484, + 483, + 462, + 440, + 439, + 396, + 395, + 352, + 351, + 308, + 307, + 264, + 263, + 220, + 219, + 176, + 132, + 88, + 44, + 0, +] + +smart185_timesteps = [ + 999, + 997, + 995, + 992, + 990, + 988, + 986, + 984, + 981, + 979, + 977, + 975, + 972, + 970, + 968, + 966, + 964, + 961, + 959, + 957, + 956, + 954, + 951, + 949, + 946, + 944, + 941, + 939, + 936, + 934, + 931, + 929, + 926, + 924, + 921, + 919, + 916, + 914, + 913, + 910, + 907, + 905, + 902, + 899, + 896, + 893, + 891, + 888, + 885, + 882, + 879, + 877, + 874, + 871, + 870, + 867, + 864, + 861, + 858, + 855, + 852, + 849, + 846, + 843, + 840, + 837, + 834, + 831, + 828, + 827, + 824, + 821, + 817, + 814, + 811, + 808, + 804, + 801, + 798, + 795, + 791, + 788, + 785, + 784, + 780, + 777, + 774, + 770, + 766, + 763, + 760, + 756, + 752, + 749, + 746, + 742, + 741, + 737, + 733, + 730, + 726, + 722, + 718, + 714, + 710, + 707, + 703, + 699, + 698, + 694, + 690, + 685, + 681, + 677, + 673, + 669, + 664, + 660, + 656, + 655, + 650, + 646, + 641, + 636, + 632, + 627, + 622, + 618, + 613, + 612, + 607, + 602, + 596, + 591, + 586, + 580, + 575, + 570, + 569, + 563, + 557, + 551, + 545, + 539, + 533, + 527, + 526, + 519, + 512, + 505, + 498, + 491, + 484, + 483, + 474, + 466, + 457, + 449, + 440, + 439, + 428, + 418, + 407, + 396, + 395, + 381, + 366, + 352, + 351, + 330, + 308, + 307, + 286, + 264, + 263, + 242, + 220, + 219, + 176, + 175, + 132, + 131, + 88, + 44, + 0, +] + +super27_timesteps = [ + 999, + 991, + 982, + 974, + 966, + 958, + 950, + 941, + 933, + 925, + 916, + 908, + 900, + 899, + 874, + 850, + 825, + 800, + 799, + 700, + 600, + 500, + 400, + 300, + 200, + 100, + 0, +] + +super40_timesteps = [ + 999, + 992, + 985, + 978, + 971, + 964, + 957, + 949, + 942, + 935, + 928, + 921, + 914, + 907, + 900, + 899, + 879, + 859, + 840, + 820, + 800, + 799, + 766, + 733, + 700, + 699, + 650, + 600, + 599, + 500, + 499, + 400, + 399, + 300, + 299, + 200, + 199, + 100, + 99, + 0, +] + +super100_timesteps = [ + 999, + 996, + 992, + 989, + 985, + 982, + 979, + 975, + 972, + 968, + 965, + 961, + 958, + 955, + 951, + 948, + 944, + 941, + 938, + 934, + 931, + 927, + 924, + 920, + 917, + 914, + 910, + 907, + 903, + 900, + 899, + 891, + 884, + 876, + 869, + 861, + 853, + 846, + 838, + 830, + 823, + 815, + 808, + 800, + 799, + 788, + 777, + 766, + 755, + 744, + 733, + 722, + 711, + 700, + 699, + 688, + 677, + 666, + 655, + 644, + 633, + 622, + 611, + 600, + 599, + 585, + 571, + 557, + 542, + 528, + 514, + 500, + 499, + 485, + 471, + 457, + 442, + 428, + 414, + 400, + 399, + 379, + 359, + 340, + 320, + 300, + 299, + 279, + 259, + 240, + 220, + 200, + 199, + 166, + 133, + 100, + 99, + 66, + 33, + 0, +] diff --git a/diffusers/src/diffusers/pipelines/deepfloyd_if/watermark.py b/diffusers/src/diffusers/pipelines/deepfloyd_if/watermark.py new file mode 100755 index 0000000..e03e3fa --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deepfloyd_if/watermark.py @@ -0,0 +1,46 @@ +from typing import List + +import PIL.Image +import torch +from PIL import Image + +from ...configuration_utils import ConfigMixin +from ...models.modeling_utils import ModelMixin +from ...utils import PIL_INTERPOLATION + + +class IFWatermarker(ModelMixin, ConfigMixin): + def __init__(self): + super().__init__() + + self.register_buffer("watermark_image", torch.zeros((62, 62, 4))) + self.watermark_image_as_pil = None + + def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None): + # Copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287 + + h = images[0].height + w = images[0].width + + sample_size = sample_size or h + + coef = min(h / sample_size, w / sample_size) + img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w) + + S1, S2 = 1024**2, img_w * img_h + K = (S2 / S1) ** 0.5 + wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K) + + if self.watermark_image_as_pil is None: + watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy() + watermark_image = Image.fromarray(watermark_image, mode="RGBA") + self.watermark_image_as_pil = watermark_image + + wm_img = self.watermark_image_as_pil.resize( + (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None + ) + + for pil_img in images: + pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1]) + + return images diff --git a/diffusers/src/diffusers/pipelines/deprecated/README.md b/diffusers/src/diffusers/pipelines/deprecated/README.md new file mode 100755 index 0000000..1e21dbb --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/README.md @@ -0,0 +1,3 @@ +# Deprecated Pipelines + +This folder contains pipelines that have very low usage as measured by model downloads, issues and PRs. While you can still use the pipelines just as before, we will stop testing the pipelines and will not accept any changes to existing files. \ No newline at end of file diff --git a/diffusers/src/diffusers/pipelines/deprecated/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/__init__.py new file mode 100755 index 0000000..9936323 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/__init__.py @@ -0,0 +1,153 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_librosa_available, + is_note_seq_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_pt_objects + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"] + _import_structure["pndm"] = ["PNDMPipeline"] + _import_structure["repaint"] = ["RePaintPipeline"] + _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"] + _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"] + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["alt_diffusion"] = [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AltDiffusionPipelineOutput", + ] + _import_structure["versatile_diffusion"] = [ + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + ] + _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"] + _import_structure["stable_diffusion_variants"] = [ + "CycleDiffusionPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionModelEditingPipeline", + ] + +try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_librosa_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) + +else: + _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"] + +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) + +else: + _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_pt_objects import * + + else: + from .latent_diffusion_uncond import LDMPipeline + from .pndm import PNDMPipeline + from .repaint import RePaintPipeline + from .score_sde_ve import ScoreSdeVePipeline + from .stochastic_karras_ve import KarrasVePipeline + + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AltDiffusionPipelineOutput + from .audio_diffusion import AudioDiffusionPipeline, Mel + from .spectrogram_diffusion import SpectrogramDiffusionPipeline + from .stable_diffusion_variants import ( + CycleDiffusionPipeline, + StableDiffusionInpaintPipelineLegacy, + StableDiffusionModelEditingPipeline, + StableDiffusionParadigmsPipeline, + StableDiffusionPix2PixZeroPipeline, + ) + from .stochastic_karras_ve import KarrasVePipeline + from .versatile_diffusion import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + from .vq_diffusion import VQDiffusionPipeline + + try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_librosa_objects import * + else: + from .audio_diffusion import AudioDiffusionPipeline, Mel + + try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 + else: + from .spectrogram_diffusion import ( + MidiProcessor, + SpectrogramDiffusionPipeline, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/__init__.py new file mode 100755 index 0000000..71fa15b --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/__init__.py @@ -0,0 +1,53 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_roberta_series"] = ["RobertaSeriesModelWithTransformation"] + _import_structure["pipeline_alt_diffusion"] = ["AltDiffusionPipeline"] + _import_structure["pipeline_alt_diffusion_img2img"] = ["AltDiffusionImg2ImgPipeline"] + + _import_structure["pipeline_output"] = ["AltDiffusionPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import * + + else: + from .modeling_roberta_series import RobertaSeriesModelWithTransformation + from .pipeline_alt_diffusion import AltDiffusionPipeline + from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline + from .pipeline_output import AltDiffusionPipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py b/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py new file mode 100755 index 0000000..f69f905 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn +from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel +from transformers.utils import ModelOutput + + +@dataclass +class TransformationModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + projection_state: Optional[torch.Tensor] = None + last_hidden_state: torch.Tensor = None + hidden_states: Optional[Tuple[torch.Tensor]] = None + attentions: Optional[Tuple[torch.Tensor]] = None + + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__( + self, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + project_dim=512, + pooler_fn="cls", + learn_encoder=False, + use_attention_mask=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + self.use_attention_mask = use_attention_mask + + +class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + base_model_prefix = "roberta" + config_class = RobertaSeriesConfig + + def __init__(self, config): + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size, config.project_dim) + self.has_pre_transformation = getattr(config, "has_pre_transformation", False) + if self.has_pre_transformation: + self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ): + r""" """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True if self.has_pre_transformation else output_hidden_states, + return_dict=return_dict, + ) + + if self.has_pre_transformation: + sequence_output2 = outputs["hidden_states"][-2] + sequence_output2 = self.pre_LN(sequence_output2) + projection_state2 = self.transformation_pre(sequence_output2) + + return TransformationModelOutput( + projection_state=projection_state2, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + projection_state = self.transformation(outputs.last_hidden_state) + return TransformationModelOutput( + projection_state=projection_state, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py new file mode 100755 index 0000000..d6730ee --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py @@ -0,0 +1,974 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .modeling_roberta_series import RobertaSeriesModelWithTransformation +from .pipeline_output import AltDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AltDiffusionPipeline + + >>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # "dark elf princess, highly detailed, d & d, fantasy, highly detailed, digital painting, trending on artstation, concept art, sharp focus, illustration, art by artgerm and greg rutkowski and fuji choko and viktoria gavrilenko and hoang lap" + >>> prompt = "黑暗精灵公主,非常详细,幻想,非常详细,数字绘画,概念艺术,敏锐的焦点,插图" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AltDiffusionPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.RobertaSeriesModelWithTransformation`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.XLMRobertaTokenizer`]): + A `XLMRobertaTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py new file mode 100755 index 0000000..6fbf5cc --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -0,0 +1,1041 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .modeling_roberta_series import RobertaSeriesModelWithTransformation +from .pipeline_output import AltDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import AltDiffusionImg2ImgPipeline + + >>> device = "cuda" + >>> model_id_or_path = "BAAI/AltDiffusion-m9" + >>> pipe = AltDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> # "A fantasy landscape, trending on artstation" + >>> prompt = "幻想风景, artstation" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + >>> images[0].save("幻想风景.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AltDiffusionImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-guided image-to-image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.RobertaSeriesModelWithTransformation`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.XLMRobertaTokenizer`]): + A `XLMRobertaTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: int = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. set timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 7.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_output.py b/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_output.py new file mode 100755 index 0000000..dd174ae --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_output.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL.Image + +from ....utils import ( + BaseOutput, +) + + +@dataclass +# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt +class AltDiffusionPipelineOutput(BaseOutput): + """ + Output class for Alt Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`List[bool]`) + List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or + `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] diff --git a/diffusers/src/diffusers/pipelines/deprecated/audio_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/audio_diffusion/__init__.py new file mode 100755 index 0000000..3127951 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/audio_diffusion/__init__.py @@ -0,0 +1,23 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = { + "mel": ["Mel"], + "pipeline_audio_diffusion": ["AudioDiffusionPipeline"], +} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .mel import Mel + from .pipeline_audio_diffusion import AudioDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/src/diffusers/pipelines/deprecated/audio_diffusion/mel.py b/diffusers/src/diffusers/pipelines/deprecated/audio_diffusion/mel.py new file mode 100755 index 0000000..3426c3a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/audio_diffusion/mel.py @@ -0,0 +1,179 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np # noqa: E402 + +from ....configuration_utils import ConfigMixin, register_to_config +from ....schedulers.scheduling_utils import SchedulerMixin + + +try: + import librosa # noqa: E402 + + _librosa_can_be_imported = True + _import_error = "" +except Exception as e: + _librosa_can_be_imported = False + _import_error = ( + f"Cannot import librosa because {e}. Make sure to correctly install librosa to be able to install it." + ) + + +from PIL import Image # noqa: E402 + + +class Mel(ConfigMixin, SchedulerMixin): + """ + Parameters: + x_res (`int`): + x resolution of spectrogram (time). + y_res (`int`): + y resolution of spectrogram (frequency bins). + sample_rate (`int`): + Sample rate of audio. + n_fft (`int`): + Number of Fast Fourier Transforms. + hop_length (`int`): + Hop length (a higher number is recommended if `y_res` < 256). + top_db (`int`): + Loudest decibel value. + n_iter (`int`): + Number of iterations for Griffin-Lim Mel inversion. + """ + + config_name = "mel_config.json" + + @register_to_config + def __init__( + self, + x_res: int = 256, + y_res: int = 256, + sample_rate: int = 22050, + n_fft: int = 2048, + hop_length: int = 512, + top_db: int = 80, + n_iter: int = 32, + ): + self.hop_length = hop_length + self.sr = sample_rate + self.n_fft = n_fft + self.top_db = top_db + self.n_iter = n_iter + self.set_resolution(x_res, y_res) + self.audio = None + + if not _librosa_can_be_imported: + raise ValueError(_import_error) + + def set_resolution(self, x_res: int, y_res: int): + """Set resolution. + + Args: + x_res (`int`): + x resolution of spectrogram (time). + y_res (`int`): + y resolution of spectrogram (frequency bins). + """ + self.x_res = x_res + self.y_res = y_res + self.n_mels = self.y_res + self.slice_size = self.x_res * self.hop_length - 1 + + def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None): + """Load audio. + + Args: + audio_file (`str`): + An audio file that must be on disk due to [Librosa](https://librosa.org/) limitation. + raw_audio (`np.ndarray`): + The raw audio file as a NumPy array. + """ + if audio_file is not None: + self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr) + else: + self.audio = raw_audio + + # Pad with silence if necessary. + if len(self.audio) < self.x_res * self.hop_length: + self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))]) + + def get_number_of_slices(self) -> int: + """Get number of slices in audio. + + Returns: + `int`: + Number of spectograms audio can be sliced into. + """ + return len(self.audio) // self.slice_size + + def get_audio_slice(self, slice: int = 0) -> np.ndarray: + """Get slice of audio. + + Args: + slice (`int`): + Slice number of audio (out of `get_number_of_slices()`). + + Returns: + `np.ndarray`: + The audio slice as a NumPy array. + """ + return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)] + + def get_sample_rate(self) -> int: + """Get sample rate. + + Returns: + `int`: + Sample rate of audio. + """ + return self.sr + + def audio_slice_to_image(self, slice: int) -> Image.Image: + """Convert slice of audio to spectrogram. + + Args: + slice (`int`): + Slice number of audio to convert (out of `get_number_of_slices()`). + + Returns: + `PIL Image`: + A grayscale image of `x_res x y_res`. + """ + S = librosa.feature.melspectrogram( + y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels + ) + log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db) + bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8) + image = Image.fromarray(bytedata) + return image + + def image_to_audio(self, image: Image.Image) -> np.ndarray: + """Converts spectrogram to audio. + + Args: + image (`PIL Image`): + An grayscale image of `x_res x y_res`. + + Returns: + audio (`np.ndarray`): + The audio as a NumPy array. + """ + bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width)) + log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db + S = librosa.db_to_power(log_S) + audio = librosa.feature.inverse.mel_to_audio( + S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter + ) + return audio diff --git a/diffusers/src/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py b/diffusers/src/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py new file mode 100755 index 0000000..47044e0 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py @@ -0,0 +1,329 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from math import acos, sin +from typing import List, Tuple, Union + +import numpy as np +import torch +from PIL import Image + +from ....models import AutoencoderKL, UNet2DConditionModel +from ....schedulers import DDIMScheduler, DDPMScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput +from .mel import Mel + + +class AudioDiffusionPipeline(DiffusionPipeline): + """ + Pipeline for audio diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + mel ([`Mel`]): + Transform audio into a spectrogram. + scheduler ([`DDIMScheduler`] or [`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`] or [`DDPMScheduler`]. + """ + + _optional_components = ["vqvae"] + + def __init__( + self, + vqvae: AutoencoderKL, + unet: UNet2DConditionModel, + mel: Mel, + scheduler: Union[DDIMScheduler, DDPMScheduler], + ): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae) + + def get_default_steps(self) -> int: + """Returns default number of steps recommended for inference. + + Returns: + `int`: + The number of steps. + """ + return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000 + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + audio_file: str = None, + raw_audio: np.ndarray = None, + slice: int = 0, + start_step: int = 0, + steps: int = None, + generator: torch.Generator = None, + mask_start_secs: float = 0, + mask_end_secs: float = 0, + step_generator: torch.Generator = None, + eta: float = 0, + noise: torch.Tensor = None, + encoding: torch.Tensor = None, + return_dict=True, + ) -> Union[ + Union[AudioPipelineOutput, ImagePipelineOutput], + Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]], + ]: + """ + The call function to the pipeline for generation. + + Args: + batch_size (`int`): + Number of samples to generate. + audio_file (`str`): + An audio file that must be on disk due to [Librosa](https://librosa.org/) limitation. + raw_audio (`np.ndarray`): + The raw audio file as a NumPy array. + slice (`int`): + Slice number of audio to convert. + start_step (int): + Step to start diffusion from. + steps (`int`): + Number of denoising steps (defaults to `50` for DDIM and `1000` for DDPM). + generator (`torch.Generator`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + mask_start_secs (`float`): + Number of seconds of audio to mask (not generate) at start. + mask_end_secs (`float`): + Number of seconds of audio to mask (not generate) at end. + step_generator (`torch.Generator`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) used to denoise. + None + eta (`float`): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + noise (`torch.Tensor`): + A noise tensor of shape `(batch_size, 1, height, width)` or `None`. + encoding (`torch.Tensor`): + A tensor for [`UNet2DConditionModel`] of shape `(batch_size, seq_length, cross_attention_dim)`. + return_dict (`bool`): + Whether or not to return a [`AudioPipelineOutput`], [`ImagePipelineOutput`] or a plain tuple. + + Examples: + + For audio diffusion: + + ```py + import torch + from IPython.display import Audio + from diffusers import DiffusionPipeline + + device = "cuda" if torch.cuda.is_available() else "cpu" + pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256").to(device) + + output = pipe() + display(output.images[0]) + display(Audio(output.audios[0], rate=mel.get_sample_rate())) + ``` + + For latent audio diffusion: + + ```py + import torch + from IPython.display import Audio + from diffusers import DiffusionPipeline + + device = "cuda" if torch.cuda.is_available() else "cpu" + pipe = DiffusionPipeline.from_pretrained("teticio/latent-audio-diffusion-256").to(device) + + output = pipe() + display(output.images[0]) + display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate())) + ``` + + For other tasks like variation, inpainting, outpainting, etc: + + ```py + output = pipe( + raw_audio=output.audios[0, 0], + start_step=int(pipe.get_default_steps() / 2), + mask_start_secs=1, + mask_end_secs=1, + ) + display(output.images[0]) + display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate())) + ``` + + Returns: + `List[PIL Image]`: + A list of Mel spectrograms (`float`, `List[np.ndarray]`) with the sample rate and raw audio. + """ + + steps = steps or self.get_default_steps() + self.scheduler.set_timesteps(steps) + step_generator = step_generator or generator + # For backwards compatibility + if isinstance(self.unet.config.sample_size, int): + self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size) + if noise is None: + noise = randn_tensor( + ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size[0], + self.unet.config.sample_size[1], + ), + generator=generator, + device=self.device, + ) + images = noise + mask = None + + if audio_file is not None or raw_audio is not None: + self.mel.load_audio(audio_file, raw_audio) + input_image = self.mel.audio_slice_to_image(slice) + input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape( + (input_image.height, input_image.width) + ) + input_image = (input_image / 255) * 2 - 1 + input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device) + + if self.vqvae is not None: + input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample( + generator=generator + )[0] + input_images = self.vqvae.config.scaling_factor * input_images + + if start_step > 0: + images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1]) + + pixels_per_second = ( + self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length + ) + mask_start = int(mask_start_secs * pixels_per_second) + mask_end = int(mask_end_secs * pixels_per_second) + mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:])) + + for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])): + if isinstance(self.unet, UNet2DConditionModel): + model_output = self.unet(images, t, encoding)["sample"] + else: + model_output = self.unet(images, t)["sample"] + + if isinstance(self.scheduler, DDIMScheduler): + images = self.scheduler.step( + model_output=model_output, + timestep=t, + sample=images, + eta=eta, + generator=step_generator, + )["prev_sample"] + else: + images = self.scheduler.step( + model_output=model_output, + timestep=t, + sample=images, + generator=step_generator, + )["prev_sample"] + + if mask is not None: + if mask_start > 0: + images[:, :, :, :mask_start] = mask[:, step, :, :mask_start] + if mask_end > 0: + images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:] + + if self.vqvae is not None: + # 0.18215 was scaling factor used in training to ensure unit variance + images = 1 / self.vqvae.config.scaling_factor * images + images = self.vqvae.decode(images)["sample"] + + images = (images / 2 + 0.5).clamp(0, 1) + images = images.cpu().permute(0, 2, 3, 1).numpy() + images = (images * 255).round().astype("uint8") + images = list( + (Image.fromarray(_[:, :, 0]) for _ in images) + if images.shape[3] == 1 + else (Image.fromarray(_, mode="RGB").convert("L") for _ in images) + ) + + audios = [self.mel.image_to_audio(_) for _ in images] + if not return_dict: + return images, (self.mel.get_sample_rate(), audios) + + return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images)) + + @torch.no_grad() + def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray: + """ + Reverse the denoising step process to recover a noisy image from the generated image. + + Args: + images (`List[PIL Image]`): + List of images to encode. + steps (`int`): + Number of encoding steps to perform (defaults to `50`). + + Returns: + `np.ndarray`: + A noise tensor of shape `(batch_size, 1, height, width)`. + """ + + # Only works with DDIM as this method is deterministic + assert isinstance(self.scheduler, DDIMScheduler) + self.scheduler.set_timesteps(steps) + sample = np.array( + [np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images] + ) + sample = (sample / 255) * 2 - 1 + sample = torch.Tensor(sample).to(self.device) + + for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))): + prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + alpha_prod_t = self.scheduler.alphas_cumprod[t] + alpha_prod_t_prev = ( + self.scheduler.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.scheduler.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + model_output = self.unet(sample, t)["sample"] + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output + sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5) + sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output + + return sample + + @staticmethod + def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor: + """Spherical Linear intERPolation. + + Args: + x0 (`torch.Tensor`): + The first tensor to interpolate between. + x1 (`torch.Tensor`): + Second tensor to interpolate between. + alpha (`float`): + Interpolation between 0 and 1 + + Returns: + `torch.Tensor`: + The interpolated tensor. + """ + + theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1)) + return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta) diff --git a/diffusers/src/diffusers/pipelines/deprecated/latent_diffusion_uncond/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/latent_diffusion_uncond/__init__.py new file mode 100755 index 0000000..214f5bb --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/latent_diffusion_uncond/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_latent_diffusion_uncond": ["LDMPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_latent_diffusion_uncond import LDMPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/src/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/diffusers/src/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py new file mode 100755 index 0000000..7fe5d59 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -0,0 +1,130 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch + +from ....models import UNet2DModel, VQModel +from ....schedulers import DDIMScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class LDMPipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation using latent diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + [`DDIMScheduler`] is used in combination with `unet` to denoise the encoded image latents. + """ + + def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): + super().__init__() + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + eta: float = 0.0, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + Number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import LDMPipeline + + >>> # load model and scheduler + >>> pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pipe().images[0] + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + + latents = randn_tensor( + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), + generator=generator, + ) + latents = latents.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + latent_model_input = self.scheduler.scale_model_input(latents, t) + # predict the noise residual + noise_prediction = self.unet(latent_model_input, t).sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample + + # adjust latents with inverse of vae scale + latents = latents / self.vqvae.config.scaling_factor + # decode the image latents with the VAE + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/deprecated/pndm/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/pndm/__init__.py new file mode 100755 index 0000000..5e3bdba --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/pndm/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_pndm": ["PNDMPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_pndm import PNDMPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/src/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py b/diffusers/src/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py new file mode 100755 index 0000000..ef78af1 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py @@ -0,0 +1,121 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +import torch + +from ....models import UNet2DModel +from ....schedulers import PNDMScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class PNDMPipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`PNDMScheduler`]): + A `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. + """ + + unet: UNet2DModel + scheduler: PNDMScheduler + + def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): + super().__init__() + + scheduler = PNDMScheduler.from_config(scheduler.config) + + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, `optional`, defaults to 1): + The number of images to generate. + num_inference_steps (`int`, `optional`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator`, `optional`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import PNDMPipeline + + >>> # load model and scheduler + >>> pndm = PNDMPipeline.from_pretrained("google/ddpm-cifar10-32") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pndm().images[0] + + >>> # save image + >>> image.save("pndm_generated_image.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + # For more information on the sampling method you can take a look at Algorithm 2 of + # the official paper: https://arxiv.org/pdf/2202.09778.pdf + + # Sample gaussian noise to begin loop + image = randn_tensor( + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), + generator=generator, + device=self.device, + ) + + self.scheduler.set_timesteps(num_inference_steps) + for t in self.progress_bar(self.scheduler.timesteps): + model_output = self.unet(image, t).sample + + image = self.scheduler.step(model_output, t, image).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/deprecated/repaint/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/repaint/__init__.py new file mode 100755 index 0000000..2c6b04a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/repaint/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_repaint": ["RePaintPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_repaint import RePaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py b/diffusers/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py new file mode 100755 index 0000000..101d315 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py @@ -0,0 +1,230 @@ +# Copyright 2024 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch + +from ....models import UNet2DModel +from ....schedulers import RePaintScheduler +from ....utils import PIL_INTERPOLATION, deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]): + if isinstance(mask, torch.Tensor): + return mask + elif isinstance(mask, PIL.Image.Image): + mask = [mask] + + if isinstance(mask[0], PIL.Image.Image): + w, h = mask[0].size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask] + mask = np.concatenate(mask, axis=0) + mask = mask.astype(np.float32) / 255.0 + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + elif isinstance(mask[0], torch.Tensor): + mask = torch.cat(mask, dim=0) + return mask + + +class RePaintPipeline(DiffusionPipeline): + r""" + Pipeline for image inpainting using RePaint. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`RePaintScheduler`]): + A `RePaintScheduler` to be used in combination with `unet` to denoise the encoded image. + """ + + unet: UNet2DModel + scheduler: RePaintScheduler + model_cpu_offload_seq = "unet" + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + image: Union[torch.Tensor, PIL.Image.Image], + mask_image: Union[torch.Tensor, PIL.Image.Image], + num_inference_steps: int = 250, + eta: float = 0.0, + jump_length: int = 10, + jump_n_sample: int = 10, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + The call function to the pipeline for generation. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + The original image to inpaint on. + mask_image (`torch.Tensor` or `PIL.Image.Image`): + The mask_image where 0.0 define which part of the original image to inpaint. + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + eta (`float`): + The weight of the added noise in a diffusion step. Its value is between 0.0 and 1.0; 0.0 corresponds to + DDIM and 1.0 is the DDPM scheduler. + jump_length (`int`, *optional*, defaults to 10): + The number of steps taken forward in time before going backward in time for a single jump ("j" in + RePaint paper). Take a look at Figure 9 and 10 in the [paper](https://arxiv.org/pdf/2201.09865.pdf). + jump_n_sample (`int`, *optional*, defaults to 10): + The number of times to make a forward time jump for a given chosen time sample. Take a look at Figure 9 + and 10 in the [paper](https://arxiv.org/pdf/2201.09865.pdf). + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from io import BytesIO + >>> import torch + >>> import PIL + >>> import requests + >>> from diffusers import RePaintPipeline, RePaintScheduler + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png" + >>> mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png" + + >>> # Load the original image and the mask as PIL images + >>> original_image = download_image(img_url).resize((256, 256)) + >>> mask_image = download_image(mask_url).resize((256, 256)) + + >>> # Load the RePaint scheduler and pipeline based on a pretrained DDPM model + >>> scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256") + >>> pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> output = pipe( + ... image=original_image, + ... mask_image=mask_image, + ... num_inference_steps=250, + ... eta=0.0, + ... jump_length=10, + ... jump_n_sample=10, + ... generator=generator, + ... ) + >>> inpainted_image = output.images[0] + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + original_image = image + + original_image = _preprocess_image(original_image) + original_image = original_image.to(device=self._execution_device, dtype=self.unet.dtype) + mask_image = _preprocess_mask(mask_image) + mask_image = mask_image.to(device=self._execution_device, dtype=self.unet.dtype) + + batch_size = original_image.shape[0] + + # sample gaussian noise to begin the loop + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image_shape = original_image.shape + image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self._execution_device) + self.scheduler.eta = eta + + t_last = self.scheduler.timesteps[0] + 1 + generator = generator[0] if isinstance(generator, list) else generator + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + if t < t_last: + # predict the noise residual + model_output = self.unet(image, t).sample + # compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample + + else: + # compute the reverse: x_t-1 -> x_t + image = self.scheduler.undo_step(image, t_last, generator) + t_last = t + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/deprecated/score_sde_ve/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/score_sde_ve/__init__.py new file mode 100755 index 0000000..87c167c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/score_sde_ve/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_score_sde_ve": ["ScoreSdeVePipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_score_sde_ve import ScoreSdeVePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/src/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py b/diffusers/src/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py new file mode 100755 index 0000000..b0bb114 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py @@ -0,0 +1,109 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch + +from ....models import UNet2DModel +from ....schedulers import ScoreSdeVeScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class ScoreSdeVePipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image. + scheduler ([`ScoreSdeVeScheduler`]): + A `ScoreSdeVeScheduler` to be used in combination with `unet` to denoise the encoded image. + """ + + unet: UNet2DModel + scheduler: ScoreSdeVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 2000, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, `optional`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma + sample = sample.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_sigmas(num_inference_steps) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) + + # correction step + for _ in range(self.scheduler.config.correct_steps): + model_output = self.unet(sample, sigma_t).sample + sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample + + # prediction step + model_output = model(sample, sigma_t).sample + output = self.scheduler.step_pred(model_output, t, sample, generator=generator) + + sample, sample_mean = output.prev_sample, output.prev_sample_mean + + sample = sample_mean.clamp(0, 1) + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + sample = self.numpy_to_pil(sample) + + if not return_dict: + return (sample,) + + return ImagePipelineOutput(images=sample) diff --git a/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/__init__.py new file mode 100755 index 0000000..150954b --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/__init__.py @@ -0,0 +1,75 @@ +# flake8: noqa +from typing import TYPE_CHECKING +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, + is_note_seq_available, + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + get_objects_from_module, +) + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["continous_encoder"] = ["SpectrogramContEncoder"] + _import_structure["notes_encoder"] = ["SpectrogramNotesEncoder"] + _import_structure["pipeline_spectrogram_diffusion"] = [ + "SpectrogramContEncoder", + "SpectrogramDiffusionPipeline", + "T5FilmDecoder", + ] +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_transformers_and_torch_and_note_seq_objects + + _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) +else: + _import_structure["midi_utils"] = ["MidiProcessor"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_spectrogram_diffusion import SpectrogramDiffusionPipeline + from .pipeline_spectrogram_diffusion import SpectrogramContEncoder + from .pipeline_spectrogram_diffusion import SpectrogramNotesEncoder + from .pipeline_spectrogram_diffusion import T5FilmDecoder + + try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_transformers_and_torch_and_note_seq_objects import * + + else: + from .midi_utils import MidiProcessor + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py b/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py new file mode 100755 index 0000000..8664c2f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py @@ -0,0 +1,92 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers.modeling_utils import ModuleUtilsMixin +from transformers.models.t5.modeling_t5 import ( + T5Block, + T5Config, + T5LayerNorm, +) + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin + + +class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + @register_to_config + def __init__( + self, + input_dims: int, + targets_context_length: int, + d_model: int, + dropout_rate: float, + num_layers: int, + num_heads: int, + d_kv: int, + d_ff: int, + feed_forward_proj: str, + is_decoder: bool = False, + ): + super().__init__() + + self.input_proj = nn.Linear(input_dims, d_model, bias=False) + + self.position_encoding = nn.Embedding(targets_context_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.dropout_pre = nn.Dropout(p=dropout_rate) + + t5config = T5Config( + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + d_ff=d_ff, + feed_forward_proj=feed_forward_proj, + dropout_rate=dropout_rate, + is_decoder=is_decoder, + is_encoder_decoder=False, + ) + self.encoders = nn.ModuleList() + for lyr_num in range(num_layers): + lyr = T5Block(t5config) + self.encoders.append(lyr) + + self.layer_norm = T5LayerNorm(d_model) + self.dropout_post = nn.Dropout(p=dropout_rate) + + def forward(self, encoder_inputs, encoder_inputs_mask): + x = self.input_proj(encoder_inputs) + + # terminal relative positional encodings + max_positions = encoder_inputs.shape[1] + input_positions = torch.arange(max_positions, device=encoder_inputs.device) + + seq_lens = encoder_inputs_mask.sum(-1) + input_positions = torch.roll(input_positions.unsqueeze(0), tuple(seq_lens.tolist()), dims=0) + x += self.position_encoding(input_positions) + + x = self.dropout_pre(x) + + # inverted the attention mask + input_shape = encoder_inputs.size() + extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) + + for lyr in self.encoders: + x = lyr(x, extended_attention_mask)[0] + x = self.layer_norm(x) + + return self.dropout_post(x), encoder_inputs_mask diff --git a/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py b/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py new file mode 100755 index 0000000..e777e84 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py @@ -0,0 +1,667 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import math +import os +from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from ....utils import is_note_seq_available +from .pipeline_spectrogram_diffusion import TARGET_FEATURE_LENGTH + + +if is_note_seq_available(): + import note_seq +else: + raise ImportError("Please install note-seq via `pip install note-seq`") + + +INPUT_FEATURE_LENGTH = 2048 + +SAMPLE_RATE = 16000 +HOP_SIZE = 320 +FRAME_RATE = int(SAMPLE_RATE // HOP_SIZE) + +DEFAULT_STEPS_PER_SECOND = 100 +DEFAULT_MAX_SHIFT_SECONDS = 10 +DEFAULT_NUM_VELOCITY_BINS = 1 + +SLAKH_CLASS_PROGRAMS = { + "Acoustic Piano": 0, + "Electric Piano": 4, + "Chromatic Percussion": 8, + "Organ": 16, + "Acoustic Guitar": 24, + "Clean Electric Guitar": 26, + "Distorted Electric Guitar": 29, + "Acoustic Bass": 32, + "Electric Bass": 33, + "Violin": 40, + "Viola": 41, + "Cello": 42, + "Contrabass": 43, + "Orchestral Harp": 46, + "Timpani": 47, + "String Ensemble": 48, + "Synth Strings": 50, + "Choir and Voice": 52, + "Orchestral Hit": 55, + "Trumpet": 56, + "Trombone": 57, + "Tuba": 58, + "French Horn": 60, + "Brass Section": 61, + "Soprano/Alto Sax": 64, + "Tenor Sax": 66, + "Baritone Sax": 67, + "Oboe": 68, + "English Horn": 69, + "Bassoon": 70, + "Clarinet": 71, + "Pipe": 73, + "Synth Lead": 80, + "Synth Pad": 88, +} + + +@dataclasses.dataclass +class NoteRepresentationConfig: + """Configuration note representations.""" + + onsets_only: bool + include_ties: bool + + +@dataclasses.dataclass +class NoteEventData: + pitch: int + velocity: Optional[int] = None + program: Optional[int] = None + is_drum: Optional[bool] = None + instrument: Optional[int] = None + + +@dataclasses.dataclass +class NoteEncodingState: + """Encoding state for note transcription, keeping track of active pitches.""" + + # velocity bin for active pitches and programs + active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass +class EventRange: + type: str + min_value: int + max_value: int + + +@dataclasses.dataclass +class Event: + type: str + value: int + + +class Tokenizer: + def __init__(self, regular_ids: int): + # The special tokens: 0=PAD, 1=EOS, and 2=UNK + self._num_special_tokens = 3 + self._num_regular_tokens = regular_ids + + def encode(self, token_ids): + encoded = [] + for token_id in token_ids: + if not 0 <= token_id < self._num_regular_tokens: + raise ValueError( + f"token_id {token_id} does not fall within valid range of [0, {self._num_regular_tokens})" + ) + encoded.append(token_id + self._num_special_tokens) + + # Add EOS token + encoded.append(1) + + # Pad to till INPUT_FEATURE_LENGTH + encoded = encoded + [0] * (INPUT_FEATURE_LENGTH - len(encoded)) + + return encoded + + +class Codec: + """Encode and decode events. + + Useful for declaring what certain ranges of a vocabulary should be used for. This is intended to be used from + Python before encoding or after decoding with GenericTokenVocabulary. This class is more lightweight and does not + include things like EOS or UNK token handling. + + To ensure that 'shift' events are always the first block of the vocab and start at 0, that event type is required + and specified separately. + """ + + def __init__(self, max_shift_steps: int, steps_per_second: float, event_ranges: List[EventRange]): + """Define Codec. + + Args: + max_shift_steps: Maximum number of shift steps that can be encoded. + steps_per_second: Shift steps will be interpreted as having a duration of + 1 / steps_per_second. + event_ranges: Other supported event types and their ranges. + """ + self.steps_per_second = steps_per_second + self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps) + self._event_ranges = [self._shift_range] + event_ranges + # Ensure all event types have unique names. + assert len(self._event_ranges) == len({er.type for er in self._event_ranges}) + + @property + def num_classes(self) -> int: + return sum(er.max_value - er.min_value + 1 for er in self._event_ranges) + + # The next couple methods are simplified special case methods just for shift + # events that are intended to be used from within autograph functions. + + def is_shift_event_index(self, index: int) -> bool: + return (self._shift_range.min_value <= index) and (index <= self._shift_range.max_value) + + @property + def max_shift_steps(self) -> int: + return self._shift_range.max_value + + def encode_event(self, event: Event) -> int: + """Encode an event to an index.""" + offset = 0 + for er in self._event_ranges: + if event.type == er.type: + if not er.min_value <= event.value <= er.max_value: + raise ValueError( + f"Event value {event.value} is not within valid range " + f"[{er.min_value}, {er.max_value}] for type {event.type}" + ) + return offset + event.value - er.min_value + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event type: {event.type}") + + def event_type_range(self, event_type: str) -> Tuple[int, int]: + """Return [min_id, max_id] for an event type.""" + offset = 0 + for er in self._event_ranges: + if event_type == er.type: + return offset, offset + (er.max_value - er.min_value) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event type: {event_type}") + + def decode_event_index(self, index: int) -> Event: + """Decode an event index to an Event.""" + offset = 0 + for er in self._event_ranges: + if offset <= index <= offset + er.max_value - er.min_value: + return Event(type=er.type, value=er.min_value + index - offset) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event index: {index}") + + +@dataclasses.dataclass +class ProgramGranularity: + # both tokens_map_fn and program_map_fn should be idempotent + tokens_map_fn: Callable[[Sequence[int], Codec], Sequence[int]] + program_map_fn: Callable[[int], int] + + +def drop_programs(tokens, codec: Codec): + """Drops program change events from a token sequence.""" + min_program_id, max_program_id = codec.event_type_range("program") + return tokens[(tokens < min_program_id) | (tokens > max_program_id)] + + +def programs_to_midi_classes(tokens, codec): + """Modifies program events to be the first program in the MIDI class.""" + min_program_id, max_program_id = codec.event_type_range("program") + is_program = (tokens >= min_program_id) & (tokens <= max_program_id) + return np.where(is_program, min_program_id + 8 * ((tokens - min_program_id) // 8), tokens) + + +PROGRAM_GRANULARITIES = { + # "flat" granularity; drop program change tokens and set NoteSequence + # programs to zero + "flat": ProgramGranularity(tokens_map_fn=drop_programs, program_map_fn=lambda program: 0), + # map each program to the first program in its MIDI class + "midi_class": ProgramGranularity( + tokens_map_fn=programs_to_midi_classes, program_map_fn=lambda program: 8 * (program // 8) + ), + # leave programs as is + "full": ProgramGranularity(tokens_map_fn=lambda tokens, codec: tokens, program_map_fn=lambda program: program), +} + + +def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1): + """ + equivalent of tf.signal.frame + """ + signal_length = signal.shape[axis] + if pad_end: + frames_overlap = frame_length - frame_step + rest_samples = np.abs(signal_length - frames_overlap) % np.abs(frame_length - frames_overlap) + pad_size = int(frame_length - rest_samples) + + if pad_size != 0: + pad_axis = [0] * signal.ndim + pad_axis[axis] = pad_size + signal = F.pad(signal, pad_axis, "constant", pad_value) + frames = signal.unfold(axis, frame_length, frame_step) + return frames + + +def program_to_slakh_program(program): + # this is done very hackily, probably should use a custom mapping + for slakh_program in sorted(SLAKH_CLASS_PROGRAMS.values(), reverse=True): + if program >= slakh_program: + return slakh_program + + +def audio_to_frames( + samples, + hop_size: int, + frame_rate: int, +) -> Tuple[Sequence[Sequence[int]], torch.Tensor]: + """Convert audio samples to non-overlapping frames and frame times.""" + frame_size = hop_size + samples = np.pad(samples, [0, frame_size - len(samples) % frame_size], mode="constant") + + # Split audio into frames. + frames = frame( + torch.Tensor(samples).unsqueeze(0), + frame_length=frame_size, + frame_step=frame_size, + pad_end=False, # TODO check why its off by 1 here when True + ) + + num_frames = len(samples) // frame_size + + times = np.arange(num_frames) / frame_rate + return frames, times + + +def note_sequence_to_onsets_and_offsets_and_programs( + ns: note_seq.NoteSequence, +) -> Tuple[Sequence[float], Sequence[NoteEventData]]: + """Extract onset & offset times and pitches & programs from a NoteSequence. + + The onset & offset times will not necessarily be in sorted order. + + Args: + ns: NoteSequence from which to extract onsets and offsets. + + Returns: + times: A list of note onset and offset times. values: A list of NoteEventData objects where velocity is zero for + note + offsets. + """ + # Sort by program and pitch and put offsets before onsets as a tiebreaker for + # subsequent stable sort. + notes = sorted(ns.notes, key=lambda note: (note.is_drum, note.program, note.pitch)) + times = [note.end_time for note in notes if not note.is_drum] + [note.start_time for note in notes] + values = [ + NoteEventData(pitch=note.pitch, velocity=0, program=note.program, is_drum=False) + for note in notes + if not note.is_drum + ] + [ + NoteEventData(pitch=note.pitch, velocity=note.velocity, program=note.program, is_drum=note.is_drum) + for note in notes + ] + return times, values + + +def num_velocity_bins_from_codec(codec: Codec): + """Get number of velocity bins from event codec.""" + lo, hi = codec.event_type_range("velocity") + return hi - lo + + +# segment an array into segments of length n +def segment(a, n): + return [a[i : i + n] for i in range(0, len(a), n)] + + +def velocity_to_bin(velocity, num_velocity_bins): + if velocity == 0: + return 0 + else: + return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY) + + +def note_event_data_to_events( + state: Optional[NoteEncodingState], + value: NoteEventData, + codec: Codec, +) -> Sequence[Event]: + """Convert note event data to a sequence of events.""" + if value.velocity is None: + # onsets only, no program or velocity + return [Event("pitch", value.pitch)] + else: + num_velocity_bins = num_velocity_bins_from_codec(codec) + velocity_bin = velocity_to_bin(value.velocity, num_velocity_bins) + if value.program is None: + # onsets + offsets + velocities only, no programs + if state is not None: + state.active_pitches[(value.pitch, 0)] = velocity_bin + return [Event("velocity", velocity_bin), Event("pitch", value.pitch)] + else: + if value.is_drum: + # drum events use a separate vocabulary + return [Event("velocity", velocity_bin), Event("drum", value.pitch)] + else: + # program + velocity + pitch + if state is not None: + state.active_pitches[(value.pitch, value.program)] = velocity_bin + return [ + Event("program", value.program), + Event("velocity", velocity_bin), + Event("pitch", value.pitch), + ] + + +def note_encoding_state_to_events(state: NoteEncodingState) -> Sequence[Event]: + """Output program and pitch events for active notes plus a final tie event.""" + events = [] + for pitch, program in sorted(state.active_pitches.keys(), key=lambda k: k[::-1]): + if state.active_pitches[(pitch, program)]: + events += [Event("program", program), Event("pitch", pitch)] + events.append(Event("tie", 0)) + return events + + +def encode_and_index_events( + state, event_times, event_values, codec, frame_times, encode_event_fn, encoding_state_to_events_fn=None +): + """Encode a sequence of timed events and index to audio frame times. + + Encodes time shifts as repeated single step shifts for later run length encoding. + + Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio + frame. This can be used e.g. to prepend events representing the current state to a targets segment. + + Args: + state: Initial event encoding state. + event_times: Sequence of event times. + event_values: Sequence of event values. + encode_event_fn: Function that transforms event value into a sequence of one + or more Event objects. + codec: An Codec object that maps Event objects to indices. + frame_times: Time for every audio frame. + encoding_state_to_events_fn: Function that transforms encoding state into a + sequence of one or more Event objects. + + Returns: + events: Encoded events and shifts. event_start_indices: Corresponding start event index for every audio frame. + Note: one event can correspond to multiple audio indices due to sampling rate differences. This makes + splitting sequences tricky because the same event can appear at the end of one sequence and the beginning of + another. + event_end_indices: Corresponding end event index for every audio frame. Used + to ensure when slicing that one chunk ends where the next begins. Should always be true that + event_end_indices[i] = event_start_indices[i + 1]. + state_events: Encoded "state" events representing the encoding state before + each event. + state_event_indices: Corresponding state event index for every audio frame. + """ + indices = np.argsort(event_times, kind="stable") + event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices] + event_values = [event_values[i] for i in indices] + + events = [] + state_events = [] + event_start_indices = [] + state_event_indices = [] + + cur_step = 0 + cur_event_idx = 0 + cur_state_event_idx = 0 + + def fill_event_start_indices_to_cur_step(): + while ( + len(event_start_indices) < len(frame_times) + and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second + ): + event_start_indices.append(cur_event_idx) + state_event_indices.append(cur_state_event_idx) + + for event_step, event_value in zip(event_steps, event_values): + while event_step > cur_step: + events.append(codec.encode_event(Event(type="shift", value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + cur_state_event_idx = len(state_events) + if encoding_state_to_events_fn: + # Dump state to state events *before* processing the next event, because + # we want to capture the state prior to the occurrence of the event. + for e in encoding_state_to_events_fn(state): + state_events.append(codec.encode_event(e)) + + for e in encode_event_fn(state, event_value, codec): + events.append(codec.encode_event(e)) + + # After the last event, continue filling out the event_start_indices array. + # The inequality is not strict because if our current step lines up exactly + # with (the start of) an audio frame, we need to add an additional shift event + # to "cover" that frame. + while cur_step / codec.steps_per_second <= frame_times[-1]: + events.append(codec.encode_event(Event(type="shift", value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + + # Now fill in event_end_indices. We need this extra array to make sure that + # when we slice events, each slice ends exactly where the subsequent slice + # begins. + event_end_indices = event_start_indices[1:] + [len(events)] + + events = np.array(events).astype(np.int32) + state_events = np.array(state_events).astype(np.int32) + event_start_indices = segment(np.array(event_start_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + event_end_indices = segment(np.array(event_end_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + state_event_indices = segment(np.array(state_event_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + + outputs = [] + for start_indices, end_indices, event_indices in zip(event_start_indices, event_end_indices, state_event_indices): + outputs.append( + { + "inputs": events, + "event_start_indices": start_indices, + "event_end_indices": end_indices, + "state_events": state_events, + "state_event_indices": event_indices, + } + ) + + return outputs + + +def extract_sequence_with_indices(features, state_events_end_token=None, feature_key="inputs"): + """Extract target sequence corresponding to audio token segment.""" + features = features.copy() + start_idx = features["event_start_indices"][0] + end_idx = features["event_end_indices"][-1] + + features[feature_key] = features[feature_key][start_idx:end_idx] + + if state_events_end_token is not None: + # Extract the state events corresponding to the audio start token, and + # prepend them to the targets array. + state_event_start_idx = features["state_event_indices"][0] + state_event_end_idx = state_event_start_idx + 1 + while features["state_events"][state_event_end_idx - 1] != state_events_end_token: + state_event_end_idx += 1 + features[feature_key] = np.concatenate( + [ + features["state_events"][state_event_start_idx:state_event_end_idx], + features[feature_key], + ], + axis=0, + ) + + return features + + +def map_midi_programs( + feature, codec: Codec, granularity_type: str = "full", feature_key: str = "inputs" +) -> Mapping[str, Any]: + """Apply MIDI program map to token sequences.""" + granularity = PROGRAM_GRANULARITIES[granularity_type] + + feature[feature_key] = granularity.tokens_map_fn(feature[feature_key], codec) + return feature + + +def run_length_encode_shifts_fn( + features, + codec: Codec, + feature_key: str = "inputs", + state_change_event_types: Sequence[str] = (), +) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + """Return a function that run-length encodes shifts for a given codec. + + Args: + codec: The Codec to use for shift events. + feature_key: The feature key for which to run-length encode shifts. + state_change_event_types: A list of event types that represent state + changes; tokens corresponding to these event types will be interpreted as state changes and redundant ones + will be removed. + + Returns: + A preprocessing function that run-length encodes single-step shifts. + """ + state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types] + + def run_length_encode_shifts(features: MutableMapping[str, Any]) -> Mapping[str, Any]: + """Combine leading/interior shifts, trim trailing shifts. + + Args: + features: Dict of features to process. + + Returns: + A dict of features. + """ + events = features[feature_key] + + shift_steps = 0 + total_shift_steps = 0 + output = np.array([], dtype=np.int32) + + current_state = np.zeros(len(state_change_event_ranges), dtype=np.int32) + + for event in events: + if codec.is_shift_event_index(event): + shift_steps += 1 + total_shift_steps += 1 + + else: + # If this event is a state change and has the same value as the current + # state, we can skip it entirely. + is_redundant = False + for i, (min_index, max_index) in enumerate(state_change_event_ranges): + if (min_index <= event) and (event <= max_index): + if current_state[i] == event: + is_redundant = True + current_state[i] = event + if is_redundant: + continue + + # Once we've reached a non-shift event, RLE all previous shift events + # before outputting the non-shift event. + if shift_steps > 0: + shift_steps = total_shift_steps + while shift_steps > 0: + output_steps = np.minimum(codec.max_shift_steps, shift_steps) + output = np.concatenate([output, [output_steps]], axis=0) + shift_steps -= output_steps + output = np.concatenate([output, [event]], axis=0) + + features[feature_key] = output + return features + + return run_length_encode_shifts(features) + + +def note_representation_processor_chain(features, codec: Codec, note_representation_config: NoteRepresentationConfig): + tie_token = codec.encode_event(Event("tie", 0)) + state_events_end_token = tie_token if note_representation_config.include_ties else None + + features = extract_sequence_with_indices( + features, state_events_end_token=state_events_end_token, feature_key="inputs" + ) + + features = map_midi_programs(features, codec) + + features = run_length_encode_shifts_fn(features, codec, state_change_event_types=["velocity", "program"]) + + return features + + +class MidiProcessor: + def __init__(self): + self.codec = Codec( + max_shift_steps=DEFAULT_MAX_SHIFT_SECONDS * DEFAULT_STEPS_PER_SECOND, + steps_per_second=DEFAULT_STEPS_PER_SECOND, + event_ranges=[ + EventRange("pitch", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), + EventRange("velocity", 0, DEFAULT_NUM_VELOCITY_BINS), + EventRange("tie", 0, 0), + EventRange("program", note_seq.MIN_MIDI_PROGRAM, note_seq.MAX_MIDI_PROGRAM), + EventRange("drum", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), + ], + ) + self.tokenizer = Tokenizer(self.codec.num_classes) + self.note_representation_config = NoteRepresentationConfig(onsets_only=False, include_ties=True) + + def __call__(self, midi: Union[bytes, os.PathLike, str]): + if not isinstance(midi, bytes): + with open(midi, "rb") as f: + midi = f.read() + + ns = note_seq.midi_to_note_sequence(midi) + ns_sus = note_seq.apply_sustain_control_changes(ns) + + for note in ns_sus.notes: + if not note.is_drum: + note.program = program_to_slakh_program(note.program) + + samples = np.zeros(int(ns_sus.total_time * SAMPLE_RATE)) + + _, frame_times = audio_to_frames(samples, HOP_SIZE, FRAME_RATE) + times, values = note_sequence_to_onsets_and_offsets_and_programs(ns_sus) + + events = encode_and_index_events( + state=NoteEncodingState(), + event_times=times, + event_values=values, + frame_times=frame_times, + codec=self.codec, + encode_event_fn=note_event_data_to_events, + encoding_state_to_events_fn=note_encoding_state_to_events, + ) + + events = [ + note_representation_processor_chain(event, self.codec, self.note_representation_config) for event in events + ] + input_tokens = [self.tokenizer.encode(event["inputs"]) for event in events] + + return input_tokens diff --git a/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py b/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py new file mode 100755 index 0000000..1259f0b --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py @@ -0,0 +1,86 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers.modeling_utils import ModuleUtilsMixin +from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin + + +class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + @register_to_config + def __init__( + self, + max_length: int, + vocab_size: int, + d_model: int, + dropout_rate: float, + num_layers: int, + num_heads: int, + d_kv: int, + d_ff: int, + feed_forward_proj: str, + is_decoder: bool = False, + ): + super().__init__() + + self.token_embedder = nn.Embedding(vocab_size, d_model) + + self.position_encoding = nn.Embedding(max_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.dropout_pre = nn.Dropout(p=dropout_rate) + + t5config = T5Config( + vocab_size=vocab_size, + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + d_ff=d_ff, + dropout_rate=dropout_rate, + feed_forward_proj=feed_forward_proj, + is_decoder=is_decoder, + is_encoder_decoder=False, + ) + + self.encoders = nn.ModuleList() + for lyr_num in range(num_layers): + lyr = T5Block(t5config) + self.encoders.append(lyr) + + self.layer_norm = T5LayerNorm(d_model) + self.dropout_post = nn.Dropout(p=dropout_rate) + + def forward(self, encoder_input_tokens, encoder_inputs_mask): + x = self.token_embedder(encoder_input_tokens) + + seq_length = encoder_input_tokens.shape[1] + inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device) + x += self.position_encoding(inputs_positions) + + x = self.dropout_pre(x) + + # inverted the attention mask + input_shape = encoder_input_tokens.size() + extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) + + for lyr in self.encoders: + x = lyr(x, extended_attention_mask)[0] + x = self.layer_norm(x) + + return self.dropout_post(x), encoder_inputs_mask diff --git a/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py b/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py new file mode 100755 index 0000000..b8ac8e1 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py @@ -0,0 +1,269 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np +import torch + +from ....models import T5FilmDecoder +from ....schedulers import DDPMScheduler +from ....utils import is_onnx_available, logging +from ....utils.torch_utils import randn_tensor + + +if is_onnx_available(): + from ...onnx_utils import OnnxRuntimeModel + +from ...pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from .continuous_encoder import SpectrogramContEncoder +from .notes_encoder import SpectrogramNotesEncoder + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +TARGET_FEATURE_LENGTH = 256 + + +class SpectrogramDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for unconditional audio generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + notes_encoder ([`SpectrogramNotesEncoder`]): + continuous_encoder ([`SpectrogramContEncoder`]): + decoder ([`T5FilmDecoder`]): + A [`T5FilmDecoder`] to denoise the encoded audio latents. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with `decoder` to denoise the encoded audio latents. + melgan ([`OnnxRuntimeModel`]): + """ + + _optional_components = ["melgan"] + + def __init__( + self, + notes_encoder: SpectrogramNotesEncoder, + continuous_encoder: SpectrogramContEncoder, + decoder: T5FilmDecoder, + scheduler: DDPMScheduler, + melgan: OnnxRuntimeModel if is_onnx_available() else Any, + ) -> None: + super().__init__() + + # From MELGAN + self.min_value = math.log(1e-5) # Matches MelGAN training. + self.max_value = 4.0 # Largest value for most examples + self.n_dims = 128 + + self.register_modules( + notes_encoder=notes_encoder, + continuous_encoder=continuous_encoder, + decoder=decoder, + scheduler=scheduler, + melgan=melgan, + ) + + def scale_features(self, features, output_range=(-1.0, 1.0), clip=False): + """Linearly scale features to network outputs range.""" + min_out, max_out = output_range + if clip: + features = torch.clip(features, self.min_value, self.max_value) + # Scale to [0, 1]. + zero_one = (features - self.min_value) / (self.max_value - self.min_value) + # Scale to [min_out, max_out]. + return zero_one * (max_out - min_out) + min_out + + def scale_to_features(self, outputs, input_range=(-1.0, 1.0), clip=False): + """Invert by linearly scaling network outputs to features range.""" + min_out, max_out = input_range + outputs = torch.clip(outputs, min_out, max_out) if clip else outputs + # Scale to [0, 1]. + zero_one = (outputs - min_out) / (max_out - min_out) + # Scale to [self.min_value, self.max_value]. + return zero_one * (self.max_value - self.min_value) + self.min_value + + def encode(self, input_tokens, continuous_inputs, continuous_mask): + tokens_mask = input_tokens > 0 + tokens_encoded, tokens_mask = self.notes_encoder( + encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask + ) + + continuous_encoded, continuous_mask = self.continuous_encoder( + encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask + ) + + return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)] + + def decode(self, encodings_and_masks, input_tokens, noise_time): + timesteps = noise_time + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=input_tokens.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(input_tokens.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(input_tokens.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + logits = self.decoder( + encodings_and_masks=encodings_and_masks, decoder_input_tokens=input_tokens, decoder_noise_time=timesteps + ) + return logits + + @torch.no_grad() + def __call__( + self, + input_tokens: List[List[int]], + generator: Optional[torch.Generator] = None, + num_inference_steps: int = 100, + return_dict: bool = True, + output_type: str = "np", + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + ) -> Union[AudioPipelineOutput, Tuple]: + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + r""" + The call function to the pipeline for generation. + + Args: + input_tokens (`List[List[int]]`): + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated audio. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Example: + + ```py + >>> from diffusers import SpectrogramDiffusionPipeline, MidiProcessor + + >>> pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion") + >>> pipe = pipe.to("cuda") + >>> processor = MidiProcessor() + + >>> # Download MIDI from: wget http://www.piano-midi.de/midis/beethoven/beethoven_hammerklavier_2.mid + >>> output = pipe(processor("beethoven_hammerklavier_2.mid")) + + >>> audio = output.audios[0] + ``` + + Returns: + [`pipelines.AudioPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated audio. + """ + + pred_mel = np.zeros([1, TARGET_FEATURE_LENGTH, self.n_dims], dtype=np.float32) + full_pred_mel = np.zeros([1, 0, self.n_dims], np.float32) + ones = torch.ones((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) + + for i, encoder_input_tokens in enumerate(input_tokens): + if i == 0: + encoder_continuous_inputs = torch.from_numpy(pred_mel[:1].copy()).to( + device=self.device, dtype=self.decoder.dtype + ) + # The first chunk has no previous context. + encoder_continuous_mask = torch.zeros((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) + else: + # The full song pipeline does not feed in a context feature, so the mask + # will be all 0s after the feature converter. Because we know we're + # feeding in a full context chunk from the previous prediction, set it + # to all 1s. + encoder_continuous_mask = ones + + encoder_continuous_inputs = self.scale_features( + encoder_continuous_inputs, output_range=[-1.0, 1.0], clip=True + ) + + encodings_and_masks = self.encode( + input_tokens=torch.IntTensor([encoder_input_tokens]).to(device=self.device), + continuous_inputs=encoder_continuous_inputs, + continuous_mask=encoder_continuous_mask, + ) + + # Sample encoder_continuous_inputs shaped gaussian noise to begin loop + x = randn_tensor( + shape=encoder_continuous_inputs.shape, + generator=generator, + device=self.device, + dtype=self.decoder.dtype, + ) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + # Denoising diffusion loop + for j, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + output = self.decode( + encodings_and_masks=encodings_and_masks, + input_tokens=x, + noise_time=t / self.scheduler.config.num_train_timesteps, # rescale to [0, 1) + ) + + # Compute previous output: x_t -> x_t-1 + x = self.scheduler.step(output, t, x, generator=generator).prev_sample + + mel = self.scale_to_features(x, input_range=[-1.0, 1.0]) + encoder_continuous_inputs = mel[:1] + pred_mel = mel.cpu().float().numpy() + + full_pred_mel = np.concatenate([full_pred_mel, pred_mel[:1]], axis=1) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, full_pred_mel) + + logger.info("Generated segment", i) + + if output_type == "np" and not is_onnx_available(): + raise ValueError( + "Cannot return output in 'np' format if ONNX is not available. Make sure to have ONNX installed or set 'output_type' to 'mel'." + ) + elif output_type == "np" and self.melgan is None: + raise ValueError( + "Cannot return output in 'np' format if melgan component is not defined. Make sure to define `self.melgan` or set 'output_type' to 'mel'." + ) + + if output_type == "np": + output = self.melgan(input_features=full_pred_mel.astype(np.float32)) + else: + output = full_pred_mel + + if not return_dict: + return (output,) + + return AudioPipelineOutput(audios=output) diff --git a/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py new file mode 100755 index 0000000..36cf1a3 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"] + _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"] + _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"] + + _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"] + _import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import * + + else: + from .pipeline_cycle_diffusion import CycleDiffusionPipeline + from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy + from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline + from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline + from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py new file mode 100755 index 0000000..777be88 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py @@ -0,0 +1,949 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import DDIMScheduler +from ....utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + if prev_timestep <= 0: + return clean_latents + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # direction pointing to x_t + e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5) + dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t + noise = std_dev_t * randn_tensor( + clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator + ) + prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise + + return prev_latents + + +def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if scheduler.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred + + noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / ( + variance ** (0.5) * eta + ) + return noise + + +class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can only be an + instance of [`DDIMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + image = image.to(device=device, dtype=dtype) + + batch_size = image.shape[0] + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + + # add noise to latents using the timestep + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + clean_latents = init_latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents, clean_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + source_prompt: Union[str, List[str]], + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + source_guidance_scale: Optional[float] = 1, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image` or tensor representing an image batch to be used as the starting point. Can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + source_guidance_scale (`float`, *optional*, defaults to 1): + Guidance scale for the source prompt. This is useful to control the amount of influence the source + prompt has for encoding. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Example: + + ```py + import requests + import torch + from PIL import Image + from io import BytesIO + + from diffusers import CycleDiffusionPipeline, DDIMScheduler + + # load the pipeline + # make sure you're logged in with `huggingface-cli login` + model_id_or_path = "CompVis/stable-diffusion-v1-4" + scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") + pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") + + # let's download an initial image + url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png" + response = requests.get(url) + init_image = Image.open(BytesIO(response.content)).convert("RGB") + init_image = init_image.resize((512, 512)) + init_image.save("horse.png") + + # let's specify a prompt + source_prompt = "An astronaut riding a horse" + prompt = "An astronaut riding an elephant" + + # call the pipeline + image = pipe( + prompt=prompt, + source_prompt=source_prompt, + image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.8, + guidance_scale=2, + source_guidance_scale=1, + ).images[0] + + image.save("horse_to_elephant.png") + + # let's try another example + # See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion + url = ( + "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png" + ) + response = requests.get(url) + init_image = Image.open(BytesIO(response.content)).convert("RGB") + init_image = init_image.resize((512, 512)) + init_image.save("black.png") + + source_prompt = "A black colored car" + prompt = "A blue colored car" + + # call the pipeline + torch.manual_seed(0) + image = pipe( + prompt=prompt, + source_prompt=source_prompt, + image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.85, + guidance_scale=3, + source_guidance_scale=1, + ).images[0] + + image.save("black_to_blue.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds_tuple = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + source_prompt_embeds_tuple = self.encode_prompt( + source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None, clip_skip=clip_skip + ) + if prompt_embeds_tuple[1] is not None: + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + else: + prompt_embeds = prompt_embeds_tuple[0] + if source_prompt_embeds_tuple[1] is not None: + source_prompt_embeds = torch.cat([source_prompt_embeds_tuple[1], source_prompt_embeds_tuple[0]]) + else: + source_prompt_embeds = source_prompt_embeds_tuple[0] + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, clean_latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + source_latents = latents + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + generator = extra_step_kwargs.pop("generator", None) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + source_latent_model_input = ( + torch.cat([source_latents] * 2) if do_classifier_free_guidance else source_latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) + + # predict the noise residual + if do_classifier_free_guidance: + concat_latent_model_input = torch.stack( + [ + source_latent_model_input[0], + latent_model_input[0], + source_latent_model_input[1], + latent_model_input[1], + ], + dim=0, + ) + concat_prompt_embeds = torch.stack( + [ + source_prompt_embeds[0], + prompt_embeds[0], + source_prompt_embeds[1], + prompt_embeds[1], + ], + dim=0, + ) + else: + concat_latent_model_input = torch.cat( + [ + source_latent_model_input, + latent_model_input, + ], + dim=0, + ) + concat_prompt_embeds = torch.cat( + [ + source_prompt_embeds, + prompt_embeds, + ], + dim=0, + ) + + concat_noise_pred = self.unet( + concat_latent_model_input, + t, + cross_attention_kwargs=cross_attention_kwargs, + encoder_hidden_states=concat_prompt_embeds, + ).sample + + # perform guidance + if do_classifier_free_guidance: + ( + source_noise_pred_uncond, + noise_pred_uncond, + source_noise_pred_text, + noise_pred_text, + ) = concat_noise_pred.chunk(4, dim=0) + + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( + source_noise_pred_text - source_noise_pred_uncond + ) + + else: + (source_noise_pred, noise_pred) = concat_noise_pred.chunk(2, dim=0) + + # Sample source_latents from the posterior distribution. + prev_source_latents = posterior_sample( + self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs + ) + # Compute noise. + noise = compute_noise( + self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs + ) + source_latents = prev_source_latents + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 9. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py new file mode 100755 index 0000000..0aa5e68 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,542 @@ +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ....configuration_utils import FrozenDict +from ....schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ....utils import deprecate, logging +from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ...pipeline_utils import DiffusionPipeline +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + return mask + + +class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to + provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True + + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[np.ndarray, PIL.Image.Image] = None, + mask_image: Union[np.ndarray, PIL.Image.Image] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.uu + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (?) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + if isinstance(image, PIL.Image.Image): + image = preprocess(image) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + latents_dtype = prompt_embeds.dtype + image = image.astype(latents_dtype) + + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=image)[0] + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) + init_latents_orig = init_latents + + # preprocess mask + if not isinstance(mask_image, np.ndarray): + mask_image = preprocess_mask(mask_image, 8) + mask_image = mask_image.astype(latents_dtype) + mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps.numpy()[-init_timestep] + timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ) + init_latents = init_latents.numpy() + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (?) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to ? in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].numpy() + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ).prev_sample + + latents = latents.numpy() + + init_latents_proper = self.scheduler.add_noise( + torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t])) + ) + + init_latents_proper = init_latents_proper.numpy() + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # There will throw an error if use safety_checker batchsize>1 + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py new file mode 100755 index 0000000..ce7ad3b --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,784 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import VaeImageProcessor +from ....loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline +from ...stable_diffusion import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) + + +def preprocess_image(image, batch_size): + w, h = image.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, batch_size, scale_factor=8): + if not isinstance(mask, torch.Tensor): + mask = mask.convert("L") + w, h = mask.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = np.vstack([mask[None]] * batch_size) + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + else: + valid_mask_channel_sizes = [1, 3] + # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W) + if mask.shape[3] in valid_mask_channel_sizes: + mask = mask.permute(0, 3, 1, 2) + elif mask.shape[1] not in valid_mask_channel_sizes: + raise ValueError( + f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension," + f" but received mask of shape {tuple(mask.shape)}" + ) + # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape + mask = mask.mean(dim=1, keepdim=True) + h, w = mask.shape[-2:] + h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8 + mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor)) + return mask + + +class StableDiffusionInpaintPipelineLegacy( + DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + deprecation_message = ( + f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality" + "by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533" + "for more information." + ) + deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False) + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator): + image = image.to(device=device, dtype=dtype) + init_latent_dist = self.vae.encode(image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = self.vae.config.scaling_factor * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + init_latents_orig = init_latents + + # add noise to latents using the timesteps + noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + return latents, init_latents_orig, noise + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.Tensor, PIL.Image.Image] = None, + mask_image: Union[torch.Tensor, PIL.Image.Image] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + add_predicted_noise: Optional[bool] = False, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the + expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to + that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + add_predicted_noise (`bool`, *optional*, defaults to True): + Use predicted noise instead of random noise when constructing noisy versions of the original image in + the reverse diffusion process + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Preprocess image and mask + if not isinstance(image, torch.Tensor): + image = preprocess_image(image, batch_size) + + mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + # encode the init image into latents and scale the latents + latents, init_latents_orig, noise = self.prepare_latents( + image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + + # 7. Prepare mask latent + mask = mask_image.to(device=device, dtype=latents.dtype) + mask = torch.cat([mask] * num_images_per_prompt) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # masking + if add_predicted_noise: + init_latents_proper = self.scheduler.add_noise( + init_latents_orig, noise_pred_uncond, torch.tensor([t]) + ) + else: + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # use original latents corresponding to unmasked portions of the image + latents = (init_latents_orig * mask) + (latents * (1 - mask)) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py new file mode 100755 index 0000000..9e91986 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py @@ -0,0 +1,830 @@ +# Copyright 2024 TIME Authors and The HuggingFace Team. All rights reserved." +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....image_processor import VaeImageProcessor +from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import PNDMScheduler +from ....schedulers.scheduling_utils import SchedulerMixin +from ....utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +AUGS_CONST = ["A photo of ", "An image of ", "A picture of "] + + +class StableDiffusionModelEditingPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + r""" + Pipeline for text-to-image model editing. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + with_to_k ([`bool`]): + Whether to edit the key projection matrices along with the value projection matrices. + with_augs ([`list`]): + Textual augmentations to apply while editing the text-to-image model. Set to `[]` for no augmentations. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + with_to_k: bool = True, + with_augs: list = AUGS_CONST, + ): + super().__init__() + + if isinstance(scheduler, PNDMScheduler): + logger.error("PNDMScheduler for this pipeline is currently not supported.") + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.with_to_k = with_to_k + self.with_augs = with_augs + + # get cross-attention layers + ca_layers = [] + + def append_ca(net_): + if net_.__class__.__name__ == "CrossAttention": + ca_layers.append(net_) + elif hasattr(net_, "children"): + for net__ in net_.children(): + append_ca(net__) + + # recursively find all cross-attention layers in unet + for net in self.unet.named_children(): + if "down" in net[0]: + append_ca(net[1]) + elif "up" in net[0]: + append_ca(net[1]) + elif "mid" in net[0]: + append_ca(net[1]) + + # get projection matrices + self.ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768] + self.projection_matrices = [l.to_v for l in self.ca_clip_layers] + self.og_matrices = [copy.deepcopy(l.to_v) for l in self.ca_clip_layers] + if self.with_to_k: + self.projection_matrices = self.projection_matrices + [l.to_k for l in self.ca_clip_layers] + self.og_matrices = self.og_matrices + [copy.deepcopy(l.to_k) for l in self.ca_clip_layers] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def edit_model( + self, + source_prompt: str, + destination_prompt: str, + lamb: float = 0.1, + restart_params: bool = True, + ): + r""" + Apply model editing via closed-form solution (see Eq. 5 in the TIME [paper](https://arxiv.org/abs/2303.08084)). + + Args: + source_prompt (`str`): + The source prompt containing the concept to be edited. + destination_prompt (`str`): + The destination prompt. Must contain all words from `source_prompt` with additional ones to specify the + target edit. + lamb (`float`, *optional*, defaults to 0.1): + The lambda parameter specifying the regularization intesity. Smaller values increase the editing power. + restart_params (`bool`, *optional*, defaults to True): + Restart the model parameters to their pre-trained version before editing. This is done to avoid edit + compounding. When it is `False`, edits accumulate. + """ + + # restart LDM parameters + if restart_params: + num_ca_clip_layers = len(self.ca_clip_layers) + for idx_, l in enumerate(self.ca_clip_layers): + l.to_v = copy.deepcopy(self.og_matrices[idx_]) + self.projection_matrices[idx_] = l.to_v + if self.with_to_k: + l.to_k = copy.deepcopy(self.og_matrices[num_ca_clip_layers + idx_]) + self.projection_matrices[num_ca_clip_layers + idx_] = l.to_k + + # set up sentences + old_texts = [source_prompt] + new_texts = [destination_prompt] + # add augmentations + base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:] + for aug in self.with_augs: + old_texts.append(aug + base) + base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:] + for aug in self.with_augs: + new_texts.append(aug + base) + + # prepare input k* and v* + old_embs, new_embs = [], [] + for old_text, new_text in zip(old_texts, new_texts): + text_input = self.tokenizer( + [old_text, new_text], + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + old_emb, new_emb = text_embeddings + old_embs.append(old_emb) + new_embs.append(new_emb) + + # identify corresponding destinations for each token in old_emb + idxs_replaces = [] + for old_text, new_text in zip(old_texts, new_texts): + tokens_a = self.tokenizer(old_text).input_ids + tokens_b = self.tokenizer(new_text).input_ids + tokens_a = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_a] + tokens_b = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_b] + num_orig_tokens = len(tokens_a) + idxs_replace = [] + j = 0 + for i in range(num_orig_tokens): + curr_token = tokens_a[i] + while tokens_b[j] != curr_token: + j += 1 + idxs_replace.append(j) + j += 1 + while j < 77: + idxs_replace.append(j) + j += 1 + while len(idxs_replace) < 77: + idxs_replace.append(76) + idxs_replaces.append(idxs_replace) + + # prepare batch: for each pair of setences, old context and new values + contexts, valuess = [], [] + for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces): + context = old_emb.detach() + values = [] + with torch.no_grad(): + for layer in self.projection_matrices: + values.append(layer(new_emb[idxs_replace]).detach()) + contexts.append(context) + valuess.append(values) + + # edit the model + for layer_num in range(len(self.projection_matrices)): + # mat1 = \lambda W + \sum{v k^T} + mat1 = lamb * self.projection_matrices[layer_num].weight + + # mat2 = \lambda I + \sum{k k^T} + mat2 = lamb * torch.eye( + self.projection_matrices[layer_num].weight.shape[1], + device=self.projection_matrices[layer_num].weight.device, + ) + + # aggregate sums for mat1, mat2 + for context, values in zip(contexts, valuess): + context_vector = context.reshape(context.shape[0], context.shape[1], 1) + context_vector_T = context.reshape(context.shape[0], 1, context.shape[1]) + value_vector = values[layer_num].reshape(values[layer_num].shape[0], values[layer_num].shape[1], 1) + for_mat1 = (value_vector @ context_vector_T).sum(dim=0) + for_mat2 = (context_vector @ context_vector_T).sum(dim=0) + mat1 += for_mat1 + mat2 += for_mat2 + + # update projection matrix + self.projection_matrices[layer_num].weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2)) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + ```py + >>> import torch + >>> from diffusers import StableDiffusionModelEditingPipeline + + >>> model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt) + + >>> pipe = pipe.to("cuda") + + >>> source_prompt = "A pack of roses" + >>> destination_prompt = "A pack of blue roses" + >>> pipe.edit_model(source_prompt, destination_prompt) + + >>> prompt = "A field of roses" + >>> image = pipe(prompt).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py new file mode 100755 index 0000000..be21900 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py @@ -0,0 +1,796 @@ +# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....image_processor import VaeImageProcessor +from ....loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DDPMParallelScheduler + >>> from diffusers import StableDiffusionParadigmsPipeline + + >>> scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler") + + >>> pipe = StableDiffusionParadigmsPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> ngpu, batch_per_device = torch.cuda.device_count(), 5 + >>> pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=[d for d in range(ngpu)]) + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0] + ``` +""" + + +class StableDiffusionParadigmsPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using a parallelized version of Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # attribute to wrap the unet with torch.nn.DataParallel when running multiple denoising steps on multiple GPUs + self.wrapped_unet = self.unet + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _cumsum(self, input, dim, debug=False): + if debug: + # cumsum_cuda_kernel does not have a deterministic implementation + # so perform cumsum on cpu for debugging purposes + return torch.cumsum(input.cpu().float(), dim=dim).to(input.device) + else: + return torch.cumsum(input, dim=dim) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + parallel: int = 10, + tolerance: float = 0.1, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + debug: bool = False, + clip_skip: int = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + parallel (`int`, *optional*, defaults to 10): + The batch size to use when doing parallel sampling. More parallelism may lead to faster inference but + requires higher memory usage and can also require more total FLOPs. + tolerance (`float`, *optional*, defaults to 0.1): + The error tolerance for determining when to slide the batch window forward for parallel sampling. Lower + tolerance usually leads to less or no degradation. Higher tolerance is faster but can risk degradation + of sample quality. The tolerance is specified as a ratio of the scheduler's noise magnitude. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + debug (`bool`, *optional*, defaults to `False`): + Whether or not to run in debug mode. In debug mode, `torch.cumsum` is evaluated using the CPU. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + extra_step_kwargs.pop("generator", None) + + # # 7. Denoising loop + scheduler = self.scheduler + parallel = min(parallel, len(scheduler.timesteps)) + + begin_idx = 0 + end_idx = parallel + latents_time_evolution_buffer = torch.stack([latents] * (len(scheduler.timesteps) + 1)) + + # We must make sure the noise of stochastic schedulers such as DDPM is sampled only once per timestep. + # Sampling inside the parallel denoising loop will mess this up, so we pre-sample the noise vectors outside the denoising loop. + noise_array = torch.zeros_like(latents_time_evolution_buffer) + for j in range(len(scheduler.timesteps)): + base_noise = randn_tensor( + shape=latents.shape, generator=generator, device=latents.device, dtype=prompt_embeds.dtype + ) + noise = (self.scheduler._get_variance(scheduler.timesteps[j]) ** 0.5) * base_noise + noise_array[j] = noise.clone() + + # We specify the error tolerance as a ratio of the scheduler's noise magnitude. We similarly compute the error tolerance + # outside of the denoising loop to avoid recomputing it at every step. + # We will be dividing the norm of the noise, so we store its inverse here to avoid a division at every step. + inverse_variance_norm = 1.0 / torch.tensor( + [scheduler._get_variance(scheduler.timesteps[j]) for j in range(len(scheduler.timesteps))] + [0] + ).to(noise_array.device) + latent_dim = noise_array[0, 0].numel() + inverse_variance_norm = inverse_variance_norm[:, None] / latent_dim + + scaled_tolerance = tolerance**2 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + steps = 0 + while begin_idx < len(scheduler.timesteps): + # these have shape (parallel_dim, 2*batch_size, ...) + # parallel_len is at most parallel, but could be less if we are at the end of the timesteps + # we are processing batch window of timesteps spanning [begin_idx, end_idx) + parallel_len = end_idx - begin_idx + + block_prompt_embeds = torch.stack([prompt_embeds] * parallel_len) + block_latents = latents_time_evolution_buffer[begin_idx:end_idx] + block_t = scheduler.timesteps[begin_idx:end_idx, None].repeat(1, batch_size * num_images_per_prompt) + t_vec = block_t + if do_classifier_free_guidance: + t_vec = t_vec.repeat(1, 2) + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([block_latents] * 2, dim=1) if do_classifier_free_guidance else block_latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t_vec) + + # if parallel_len is small, no need to use multiple GPUs + net = self.wrapped_unet if parallel_len > 3 else self.unet + # predict the noise residual, shape is now [parallel_len * 2 * batch_size * num_images_per_prompt, ...] + model_output = net( + latent_model_input.flatten(0, 1), + t_vec.flatten(0, 1), + encoder_hidden_states=block_prompt_embeds.flatten(0, 1), + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + per_latent_shape = model_output.shape[1:] + if do_classifier_free_guidance: + model_output = model_output.reshape( + parallel_len, 2, batch_size * num_images_per_prompt, *per_latent_shape + ) + noise_pred_uncond, noise_pred_text = model_output[:, 0], model_output[:, 1] + model_output = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + model_output = model_output.reshape( + parallel_len * batch_size * num_images_per_prompt, *per_latent_shape + ) + + block_latents_denoise = scheduler.batch_step_no_noise( + model_output=model_output, + timesteps=block_t.flatten(0, 1), + sample=block_latents.flatten(0, 1), + **extra_step_kwargs, + ).reshape(block_latents.shape) + + # back to shape (parallel_dim, batch_size, ...) + # now we want to add the pre-sampled noise + # parallel sampling algorithm requires computing the cumulative drift from the beginning + # of the window, so we need to compute cumulative sum of the deltas and the pre-sampled noises. + delta = block_latents_denoise - block_latents + cumulative_delta = self._cumsum(delta, dim=0, debug=debug) + cumulative_noise = self._cumsum(noise_array[begin_idx:end_idx], dim=0, debug=debug) + + # if we are using an ODE-like scheduler (like DDIM), we don't want to add noise + if scheduler._is_ode_scheduler: + cumulative_noise = 0 + + block_latents_new = ( + latents_time_evolution_buffer[begin_idx][None,] + cumulative_delta + cumulative_noise + ) + cur_error = torch.linalg.norm( + (block_latents_new - latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1]).reshape( + parallel_len, batch_size * num_images_per_prompt, -1 + ), + dim=-1, + ).pow(2) + error_ratio = cur_error * inverse_variance_norm[begin_idx + 1 : end_idx + 1] + + # find the first index of the vector error_ratio that is greater than error tolerance + # we can shift the window for the next iteration up to this index + error_ratio = torch.nn.functional.pad( + error_ratio, (0, 0, 0, 1), value=1e9 + ) # handle the case when everything is below ratio, by padding the end of parallel_len dimension + any_error_at_time = torch.max(error_ratio > scaled_tolerance, dim=1).values.int() + ind = torch.argmax(any_error_at_time).item() + + # compute the new begin and end idxs for the window + new_begin_idx = begin_idx + min(1 + ind, parallel) + new_end_idx = min(new_begin_idx + parallel, len(scheduler.timesteps)) + + # store the computed latents for the current window in the global buffer + latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1] = block_latents_new + # initialize the new sliding window latents with the end of the current window, + # should be better than random initialization + latents_time_evolution_buffer[end_idx : new_end_idx + 1] = latents_time_evolution_buffer[end_idx][ + None, + ] + + steps += 1 + + progress_bar.update(new_begin_idx - begin_idx) + if callback is not None and steps % callback_steps == 0: + callback(begin_idx, block_t[begin_idx], latents_time_evolution_buffer[begin_idx]) + + begin_idx = new_begin_idx + end_idx = new_end_idx + + latents = latents_time_evolution_buffer[-1] + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py new file mode 100755 index 0000000..2978972 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py @@ -0,0 +1,1310 @@ +# Copyright 2024 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + BlipForConditionalGeneration, + BlipProcessor, + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, +) + +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.attention_processor import Attention +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler +from ....schedulers.scheduling_ddim_inverse import DDIMInverseScheduler +from ....utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin): + """ + Output class for Stable Diffusion pipelines. + + Args: + latents (`torch.Tensor`) + inverted latents tensor + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + latents: torch.Tensor + images: Union[List[PIL.Image.Image], np.ndarray] + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + + >>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline + + + >>> def download(embedding_url, local_filepath): + ... r = requests.get(embedding_url) + ... with open(local_filepath, "wb") as f: + ... f.write(r.content) + + + >>> model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16) + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.to("cuda") + + >>> prompt = "a high resolution painting of a cat in the style of van gough" + >>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt" + >>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt" + + >>> for url in [source_emb_url, target_emb_url]: + ... download(url, url.split("/")[-1]) + + >>> src_embeds = torch.load(source_emb_url.split("/")[-1]) + >>> target_embeds = torch.load(target_emb_url.split("/")[-1]) + >>> images = pipeline( + ... prompt, + ... source_embeds=src_embeds, + ... target_embeds=target_embeds, + ... num_inference_steps=50, + ... cross_attention_guidance_amount=0.15, + ... ).images + + >>> images[0].save("edited_image_dog.png") + ``` +""" + +EXAMPLE_INVERT_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from transformers import BlipForConditionalGeneration, BlipProcessor + >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline + + >>> import requests + >>> from PIL import Image + + >>> captioner_id = "Salesforce/blip-image-captioning-base" + >>> processor = BlipProcessor.from_pretrained(captioner_id) + >>> model = BlipForConditionalGeneration.from_pretrained( + ... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ... ) + + >>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( + ... sd_model_ckpt, + ... caption_generator=model, + ... caption_processor=processor, + ... torch_dtype=torch.float16, + ... safety_checker=None, + ... ) + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" + + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) + >>> # generate caption + >>> caption = pipeline.generate_caption(raw_image) + + >>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii" + >>> inv_latents = pipeline.invert(caption, image=raw_image).latents + >>> # we need to generate source and target embeds + + >>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] + + >>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] + + >>> source_embeds = pipeline.get_embeds(source_prompts) + >>> target_embeds = pipeline.get_embeds(target_prompts) + >>> # the latents can then be used to edit a real image + >>> # when using Stable Diffusion 2 or other models that use v-prediction + >>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion + + >>> image = pipeline( + ... caption, + ... source_embeds=source_embeds, + ... target_embeds=target_embeds, + ... num_inference_steps=50, + ... cross_attention_guidance_amount=0.15, + ... generator=generator, + ... latents=inv_latents, + ... negative_prompt=caption, + ... ).images[0] + >>> image.save("edited_image.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def prepare_unet(unet: UNet2DConditionModel): + """Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations.""" + pix2pix_zero_attn_procs = {} + for name in unet.attn_processors.keys(): + module_name = name.replace(".processor", "") + module = unet.get_submodule(module_name) + if "attn2" in name: + pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=True) + module.requires_grad_(True) + else: + pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=False) + module.requires_grad_(False) + + unet.set_attn_processor(pix2pix_zero_attn_procs) + return unet + + +class Pix2PixZeroL2Loss: + def __init__(self): + self.loss = 0.0 + + def compute_loss(self, predictions, targets): + self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0) + + +class Pix2PixZeroAttnProcessor: + """An attention processor class to store the attention weights. + In Pix2Pix Zero, it happens during computations in the cross-attention blocks.""" + + def __init__(self, is_pix2pix_zero=False): + self.is_pix2pix_zero = is_pix2pix_zero + if self.is_pix2pix_zero: + self.reference_cross_attn_map = {} + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + timestep=None, + loss=None, + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + if self.is_pix2pix_zero and timestep is not None: + # new bookkeeping to save the attention weights. + if loss is None: + self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu() + # compute loss + elif loss is not None: + prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item()) + loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device)) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for pixel-level image editing using Pix2Pix Zero. Based on Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + requires_safety_checker (bool): + Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the + pipeline publicly. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "caption_generator", + "caption_processor", + "inverse_scheduler", + ] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler], + feature_extractor: CLIPImageProcessor, + safety_checker: StableDiffusionSafetyChecker, + inverse_scheduler: DDIMInverseScheduler, + caption_generator: BlipForConditionalGeneration, + caption_processor: BlipProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + caption_processor=caption_processor, + caption_generator=caption_generator, + inverse_scheduler=inverse_scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + source_embeds, + target_embeds, + callback_steps, + prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if source_embeds is None and target_embeds is None: + raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def generate_caption(self, images): + """Generates caption for a given image.""" + text = "a photography of" + + prev_device = self.caption_generator.device + + device = self._execution_device + inputs = self.caption_processor(images, text, return_tensors="pt").to( + device=device, dtype=self.caption_generator.dtype + ) + self.caption_generator.to(device) + outputs = self.caption_generator.generate(**inputs, max_new_tokens=128) + + # offload caption generator + self.caption_generator.to(prev_device) + + caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0] + return caption + + def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor): + """Constructs the edit direction to steer the image generation process semantically.""" + return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) + + @torch.no_grad() + def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.Tensor: + num_prompts = len(prompt) + embeds = [] + for i in range(0, num_prompts, batch_size): + prompt_slice = prompt[i : i + batch_size] + + input_ids = self.tokenizer( + prompt_slice, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids + + input_ids = input_ids.to(self.text_encoder.device) + embeds.append(self.text_encoder(input_ids)[0]) + + return torch.cat(embeds, dim=0).mean(0)[None] + + def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + if batch_size != latents.shape[0]: + if batch_size % latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_latents_per_image = batch_size // latents.shape[0] + latents = torch.cat([latents] * additional_latents_per_image, dim=0) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." + ) + else: + latents = torch.cat([latents], dim=0) + + return latents + + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + def auto_corr_loss(self, hidden_states, generator=None): + reg_loss = 0.0 + for i in range(hidden_states.shape[0]): + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) + return reg_loss + + def kl_divergence(self, hidden_states): + mean = hidden_states.mean() + var = hidden_states.var() + return var + mean**2 - 1 - torch.log(var + 1e-7) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + source_embeds: torch.Tensor = None, + target_embeds: torch.Tensor = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + cross_attention_guidance_amount: float = 0.1, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + source_embeds (`torch.Tensor`): + Source concept embeddings. Generation of the embeddings as per the [original + paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. + target_embeds (`torch.Tensor`): + Target concept embeddings. Generation of the embeddings as per the [original + paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + cross_attention_guidance_amount (`float`, defaults to 0.1): + Amount of guidance needed from the reference cross-attention maps. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Define the spatial resolutions. + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + source_embeds, + target_embeds, + callback_steps, + prompt_embeds, + ) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Generate the inverted noise from the input image or any other image + # generated from the input prompt. + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + latents_init = latents.clone() + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Rejig the UNet so that we can obtain the cross-attenion maps and + # use them for guiding the subsequent image generation. + self.unet = prepare_unet(self.unet) + + # 7. Denoising loop where we obtain the cross-attention maps. + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs={"timestep": t}, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 8. Compute the edit directions. + edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device) + + # 9. Edit the prompt embeddings as per the edit directions discovered. + prompt_embeds_edit = prompt_embeds.clone() + prompt_embeds_edit[1:2] += edit_direction + + # 10. Second denoising loop to generate the edited image. + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + latents = latents_init + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # we want to learn the latent such that it steers the generation + # process towards the edited direction, so make the make initial + # noise learnable + x_in = latent_model_input.detach().clone() + x_in.requires_grad = True + + # optimizer + opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount) + + with torch.enable_grad(): + # initialize loss + loss = Pix2PixZeroL2Loss() + + # predict the noise residual + noise_pred = self.unet( + x_in, + t, + encoder_hidden_states=prompt_embeds_edit.detach(), + cross_attention_kwargs={"timestep": t, "loss": loss}, + ).sample + + loss.loss.backward(retain_graph=False) + opt.step() + + # recompute the noise + noise_pred = self.unet( + x_in.detach(), + t, + encoder_hidden_states=prompt_embeds_edit, + cross_attention_kwargs={"timestep": None}, + ).sample + + latents = x_in.detach().chunk(2)[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) + def invert( + self, + prompt: Optional[str] = None, + image: PipelineImageInput = None, + num_inference_steps: int = 50, + guidance_scale: float = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + cross_attention_guidance_amount: float = 0.1, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 5, + num_auto_corr_rolls: int = 5, + ): + r""" + Function used to generate inverted latents given a prompt and image. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept + image latents as `image`, if passing latents directly, it will not be encoded again. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 1): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + cross_attention_guidance_amount (`float`, defaults to 0.1): + Amount of guidance needed from the reference cross-attention maps. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback–Leibler divergence output + num_reg_steps (`int`, *optional*, defaults to 5): + Number of regularization loss steps + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps + + Examples: + + Returns: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or + `tuple`: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted + latents tensor and then second is the corresponding decoded image. + """ + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. Prepare latent variables + latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator) + + # 5. Encode input prompt + num_images_per_prompt = 1 + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.inverse_scheduler.timesteps + + # 6. Rejig the UNet so that we can obtain the cross-attenion maps and + # use them for guiding the subsequent image generation. + self.unet = prepare_unet(self.unet) + + # 7. Denoising loop where we obtain the cross-attention maps. + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs={"timestep": t}, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = self.auto_corr_loss(var_epsilon, generator=generator) + l_ac.backward() + + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad + + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = self.kl_divergence(var_epsilon) + l_kld.backward() + + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad + + noise_pred = noise_pred.detach() + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + inverted_latents = latents.detach().clone() + + # 8. Post-processing + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (inverted_latents, image) + + return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image) diff --git a/diffusers/src/diffusers/pipelines/deprecated/stochastic_karras_ve/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/stochastic_karras_ve/__init__.py new file mode 100755 index 0000000..15c9a8c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/stochastic_karras_ve/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_stochastic_karras_ve": ["KarrasVePipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_stochastic_karras_ve import KarrasVePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/src/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/diffusers/src/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py new file mode 100755 index 0000000..023edb4 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -0,0 +1,128 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch + +from ....models import UNet2DModel +from ....schedulers import KarrasVeScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class KarrasVePipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation. + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image. + scheduler ([`KarrasVeScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. + """ + + # add type hints for linting + unet: UNet2DModel + scheduler: KarrasVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + # sample x_0 ~ N(0, sigma_0^2 * I) + sample = randn_tensor(shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma + + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # here sigma_t == t_i from the paper + sigma = self.scheduler.schedule[t] + sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 + + # 1. Select temporarily increased noise level sigma_hat + # 2. Add new noise to move from sample_i to sample_hat + sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator) + + # 3. Predict the noise residual given the noise magnitude `sigma_hat` + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample + + # 4. Evaluate dx/dt at sigma_hat + # 5. Take Euler step from sigma to sigma_prev + step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat) + + if sigma_prev != 0: + # 6. Apply 2nd order correction + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample + step_output = self.scheduler.step_correct( + model_output, + sigma_hat, + sigma_prev, + sample_hat, + step_output.prev_sample, + step_output["derivative"], + ) + sample = step_output.prev_sample + + sample = (sample / 2 + 0.5).clamp(0, 1) + image = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/__init__.py new file mode 100755 index 0000000..8ea6ef6 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/__init__.py @@ -0,0 +1,71 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + + _dummy_objects.update( + { + "VersatileDiffusionDualGuidedPipeline": VersatileDiffusionDualGuidedPipeline, + "VersatileDiffusionImageVariationPipeline": VersatileDiffusionImageVariationPipeline, + "VersatileDiffusionPipeline": VersatileDiffusionPipeline, + "VersatileDiffusionTextToImagePipeline": VersatileDiffusionTextToImagePipeline, + } + ) +else: + _import_structure["modeling_text_unet"] = ["UNetFlatConditionModel"] + _import_structure["pipeline_versatile_diffusion"] = ["VersatileDiffusionPipeline"] + _import_structure["pipeline_versatile_diffusion_dual_guided"] = ["VersatileDiffusionDualGuidedPipeline"] + _import_structure["pipeline_versatile_diffusion_image_variation"] = ["VersatileDiffusionImageVariationPipeline"] + _import_structure["pipeline_versatile_diffusion_text_to_image"] = ["VersatileDiffusionTextToImagePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + else: + from .pipeline_versatile_diffusion import VersatileDiffusionPipeline + from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline + from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline + from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py new file mode 100755 index 0000000..3937e87 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -0,0 +1,2518 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.utils import deprecate + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin +from ....models.activations import get_activation +from ....models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + AttnProcessor, +) +from ....models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from ....models.resnet import ResnetBlockCondNorm2D +from ....models.transformers.dual_transformer_2d import DualTransformer2DModel +from ....models.transformers.transformer_2d import Transformer2DModel +from ....models.unets.unet_2d_condition import UNet2DConditionOutput +from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import apply_freeu + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + transformer_layers_per_block, + attention_type, + attention_head_dim, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + dropout=0.0, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlockFlat": + return DownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat") + return CrossAttnDownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} is not supported.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + transformer_layers_per_block, + resolution_idx, + attention_type, + attention_head_dim, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + dropout=0.0, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlockFlat": + return UpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat") + return CrossAttnUpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} is not supported.") + + +class FourierEmbedder(nn.Module): + def __init__(self, num_freqs=64, temperature=100): + super().__init__() + + self.num_freqs = num_freqs + self.temperature = temperature + + freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs) + freq_bands = freq_bands[None, None, None] + self.register_buffer("freq_bands", freq_bands, persistent=False) + + def __call__(self, x): + x = self.freq_bands * x.unsqueeze(-1) + return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1) + + +class GLIGENTextBoundingboxProjection(nn.Module): + def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8): + super().__init__() + self.positive_len = positive_len + self.out_dim = out_dim + + self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) + self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy + + if isinstance(out_dim, tuple): + out_dim = out_dim[0] + + if feature_type == "text-only": + self.linears = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + elif feature_type == "text-image": + self.linears_text = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.linears_image = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) + + def forward( + self, + boxes, + masks, + positive_embeddings=None, + phrases_masks=None, + image_masks=None, + phrases_embeddings=None, + image_embeddings=None, + ): + masks = masks.unsqueeze(-1) + + xyxy_embedding = self.fourier_embedder(boxes) + xyxy_null = self.null_position_feature.view(1, 1, -1) + xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null + + if positive_embeddings: + positive_null = self.null_positive_feature.view(1, 1, -1) + positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null + + objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) + else: + phrases_masks = phrases_masks.unsqueeze(-1) + image_masks = image_masks.unsqueeze(-1) + + text_null = self.null_text_feature.view(1, 1, -1) + image_null = self.null_image_feature.view(1, 1, -1) + + phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null + image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null + + objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) + objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) + objs = torch.cat([objs_text, objs_image], dim=1) + + return objs + + +class UNetFlatConditionModel(ModelMixin, ConfigMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlockFlatCrossAttn`, `UNetMidBlockFlat`, or + `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], + [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], + [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"] + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "DownBlockFlat", + ), + mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = LinearMultiDim( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlockFlatCrossAttn": + self.mid_block = UNetMidBlockFlatCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": + self.mid_block = UNetMidBlockFlatSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlockFlat": + self.mid_block = UNetMidBlockFlat( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = LinearMultiDim( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, (list, tuple)): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def unload_lora(self): + """Unloads LoRA weights.""" + deprecate( + "unload_lora", + "0.28.0", + "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", + ) + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNetFlatConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlockFlat + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +class LinearMultiDim(nn.Linear): + def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): + in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features) + if out_features is None: + out_features = in_features + out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features) + self.in_features_multidim = in_features + self.out_features_multidim = out_features + super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) + + def forward(self, input_tensor, *args, **kwargs): + shape = input_tensor.shape + n_dim = len(self.in_features_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) + output_tensor = super().forward(input_tensor) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim) + return output_tensor + + +class ResnetBlockFlat(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + time_embedding_norm="default", + use_in_shortcut=None, + second_dim=4, + **kwargs, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + + in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels) + self.in_channels_prod = np.array(in_channels).prod() + self.channels_multidim = in_channels + + if out_channels is not None: + out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels) + out_channels_prod = np.array(out_channels).prod() + self.out_channels_multidim = out_channels + else: + out_channels_prod = self.in_channels_prod + self.out_channels_multidim = self.channels_multidim + self.time_embedding_norm = time_embedding_norm + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = ( + self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, input_tensor, temb): + shape = input_tensor.shape + n_dim = len(self.channels_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1) + input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + output_tensor = output_tensor.view(*shape[0:-n_dim], -1) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim) + + return output_tensor + + +class DownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + additional_residuals: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class UpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class CrossAttnUpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlat(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, + height, width)`. + + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + resnet_groups_out = resnet_groups_out or resnet_groups + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + groups_out=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups_out, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb) + + return hidden_states diff --git a/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py new file mode 100755 index 0000000..c8dc18e --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py @@ -0,0 +1,421 @@ +import inspect +from typing import Callable, List, Optional, Union + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel + +from ....models import AutoencoderKL, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import logging +from ...pipeline_utils import DiffusionPipeline +from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline +from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline +from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModel + image_encoder: CLIPVisionModel + image_unet: UNet2DConditionModel + text_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPImageProcessor, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModel, + image_unet: UNet2DConditionModel, + text_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + @torch.no_grad() + def image_variation( + self, + image: Union[torch.Tensor, PIL.Image.Image], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.image_variation(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + return VersatileDiffusionImageVariationPipeline(**components)( + image=image, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + + @torch.no_grad() + def text_to_image( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionTextToImagePipeline(**components) + output = temp_pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + # swap the attention blocks back to the original state + temp_pipeline._swap_unet_attention_blocks() + + return output + + @torch.no_grad() + def dual_guided( + self, + prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], + image: Union[str, List[str]], + text_to_image_strength: float = 0.5, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe.dual_guided( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components) + output = temp_pipeline( + prompt=prompt, + image=image, + text_to_image_strength=text_to_image_strength, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + temp_pipeline._revert_dual_attention() + + return output diff --git a/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py new file mode 100755 index 0000000..2212651 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -0,0 +1,561 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.utils.checkpoint +from transformers import ( + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ....image_processor import VaeImageProcessor +from ....models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): + r""" + Pipeline for image-text dual-guided generation using Versatile Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModelWithProjection + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPImageProcessor, + text_encoder: CLIPTextModelWithProjection, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + if self.text_unet is not None and ( + "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention + ): + # if loading from a universal checkpoint rather than a saved dual-guided pipeline + self._convert_to_dual_attention() + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + def _convert_to_dual_attention(self): + """ + Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks + from both `image_unet` and `text_unet` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + + image_transformer = self.image_unet.get_submodule(parent_name)[index] + text_transformer = self.text_unet.get_submodule(parent_name)[index] + + config = image_transformer.config + dual_transformer = DualTransformer2DModel( + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + in_channels=config.in_channels, + num_layers=config.num_layers, + dropout=config.dropout, + norm_num_groups=config.norm_num_groups, + cross_attention_dim=config.cross_attention_dim, + attention_bias=config.attention_bias, + sample_size=config.sample_size, + num_vector_embeds=config.num_vector_embeds, + activation_fn=config.activation_fn, + num_embeds_ada_norm=config.num_embeds_ada_norm, + ) + dual_transformer.transformers[0] = image_transformer + dual_transformer.transformers[1] = text_transformer + + self.image_unet.get_submodule(parent_name)[index] = dual_transformer + self.image_unet.register_to_config(dual_cross_attention=True) + + def _revert_dual_attention(self): + """ + Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call + this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index] = module.transformers[0] + + self.image_unet.register_to_config(dual_cross_attention=False) + + def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = normalize_embeddings(prompt_embeds) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + negative_prompt_embeds = self.image_encoder(pixel_values) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, image, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}") + if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): + raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")): + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + module.mix_ratio = mix_ratio + + for i, type in enumerate(condition_types): + if type == "text": + module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings + module.transformer_index_for_condition[i] = 1 # use the second (text) transformer + else: + module.condition_lengths[i] = 257 + module.transformer_index_for_condition[i] = 0 # use the first (image) transformer + + @torch.no_grad() + def __call__( + self, + prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], + image: Union[str, List[str]], + text_to_image_strength: float = 0.5, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionDualGuidedPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, image, height, width, callback_steps) + + # 2. Define call parameters + prompt = [prompt] if not isinstance(prompt, list) else prompt + image = [image] if not isinstance(image, list) else image + batch_size = len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) + dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1) + prompt_types = ("text", "image") + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dual_prompt_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Combine the attention blocks of the image and text UNets + self.set_transformer_params(text_to_image_strength, prompt_types) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py new file mode 100755 index 0000000..62d3e83 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -0,0 +1,402 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.utils.checkpoint +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ....image_processor import VaeImageProcessor +from ....models import AutoencoderKL, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): + r""" + Pipeline for image variation using Versatile Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + image_feature_extractor: CLIPImageProcessor + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + def __init__( + self, + image_feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + image_feature_extractor=image_feature_extractor, + image_encoder=image_encoder, + image_unet=image_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4: + prompt = list(prompt) + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images: List[str] + if negative_prompt is None: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, PIL.Image.Image): + uncond_images = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_images = negative_prompt + + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + negative_prompt_embeds = self.image_encoder(pixel_values) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionImageVariationPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + image_embeddings = self._encode_prompt( + image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py new file mode 100755 index 0000000..de4c2ac --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -0,0 +1,480 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch +import torch.utils.checkpoint +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer + +from ....image_processor import VaeImageProcessor +from ....models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Versatile Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + if self.text_unet is not None: + self._swap_unet_attention_blocks() + + def _swap_unet_attention_blocks(self): + """ + Swap the `Transformer2DModel` blocks between the image and text UNets + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = ( + self.text_unet.get_submodule(parent_name)[index], + self.image_unet.get_submodule(parent_name)[index], + ) + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = normalize_embeddings(prompt_embeds) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionTextToImagePipeline + >>> import torch + + >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/deprecated/vq_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/deprecated/vq_diffusion/__init__.py new file mode 100755 index 0000000..0709033 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/vq_diffusion/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + LearnedClassifierFreeSamplingEmbeddings, + VQDiffusionPipeline, + ) + + _dummy_objects.update( + { + "LearnedClassifierFreeSamplingEmbeddings": LearnedClassifierFreeSamplingEmbeddings, + "VQDiffusionPipeline": VQDiffusionPipeline, + } + ) +else: + _import_structure["pipeline_vq_diffusion"] = ["LearnedClassifierFreeSamplingEmbeddings", "VQDiffusionPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + LearnedClassifierFreeSamplingEmbeddings, + VQDiffusionPipeline, + ) + else: + from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py b/diffusers/src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py new file mode 100755 index 0000000..8dee000 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py @@ -0,0 +1,325 @@ +# Copyright 2024 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin, Transformer2DModel, VQModel +from ....schedulers import VQDiffusionScheduler +from ....utils import logging +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): + """ + Utility class for storing learned text embeddings for classifier free sampling + """ + + @register_to_config + def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None): + super().__init__() + + self.learnable = learnable + + if self.learnable: + assert hidden_size is not None, "learnable=True requires `hidden_size` to be set" + assert length is not None, "learnable=True requires `length` to be set" + + embeddings = torch.zeros(length, hidden_size) + else: + embeddings = None + + self.embeddings = torch.nn.Parameter(embeddings) + + +class VQDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using VQ Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vqvae ([`VQModel`]): + Vector Quantized Variational Auto-Encoder (VAE) model to encode and decode images to and from latent + representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + transformer ([`Transformer2DModel`]): + A conditional `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`VQDiffusionScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + vqvae: VQModel + text_encoder: CLIPTextModel + tokenizer: CLIPTokenizer + transformer: Transformer2DModel + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings + scheduler: VQDiffusionScheduler + + def __init__( + self, + vqvae: VQModel, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + transformer: Transformer2DModel, + scheduler: VQDiffusionScheduler, + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, + ) + + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0] + + # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. + # While CLIP does normalize the pooled output of the text transformer when combining + # the image and text embeddings, CLIP does not directly normalize the last hidden state. + # + # CLIP normalizing the pooled output. + # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 + prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True) + + # duplicate text embeddings for each generation per prompt + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + if self.learned_classifier_free_sampling_embeddings.learnable: + negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings + negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1) + else: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + # See comment for normalizing text embeddings + negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + num_inference_steps: int = 100, + guidance_scale: float = 5.0, + truncation_rate: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): + Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at + most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above + `truncation_rate` are set to zero. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor` of shape (batch), *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Must be valid embedding indices.If not provided, a latents tensor will be generated of + completely masked latent pixels. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get the initial completely masked latents unless the user supplied it + + latents_shape = (batch_size, self.transformer.num_latent_pixels) + if latents is None: + mask_class = self.transformer.num_vector_embeds - 1 + latents = torch.full(latents_shape, mask_class).to(self.device) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any(): + raise ValueError( + "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," + f" {self.transformer.num_vector_embeds - 1} (inclusive)." + ) + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + sample = latents + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the sample if we are doing classifier free guidance + latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample + + # predict the un-noised image + # model_output == `log_p_x_0` + model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample + + if do_classifier_free_guidance: + model_output_uncond, model_output_text = model_output.chunk(2) + model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond) + model_output -= torch.logsumexp(model_output, dim=1, keepdim=True) + + model_output = self.truncate(model_output, truncation_rate) + + # remove `log(0)`'s (`-inf`s) + model_output = model_output.clamp(-70) + + # compute the previous noisy sample x_t -> x_t-1 + sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + embedding_channels = self.vqvae.config.vq_embed_dim + embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels) + embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape) + image = self.vqvae.decode(embeddings, force_not_quantize=True).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + def truncate(self, log_p_x_0: torch.Tensor, truncation_rate: float) -> torch.Tensor: + """ + Truncates `log_p_x_0` such that for each column vector, the total cumulative probability is `truncation_rate` + The lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to + zero. + """ + sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True) + sorted_p_x_0 = torch.exp(sorted_log_p_x_0) + keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate + + # Ensure that at least the largest probability is not zeroed out + all_true = torch.full_like(keep_mask[:, 0:1, :], True) + keep_mask = torch.cat((all_true, keep_mask), dim=1) + keep_mask = keep_mask[:, :-1, :] + + keep_mask = keep_mask.gather(1, indices.argsort(1)) + + rv = log_p_x_0.clone() + + rv[~keep_mask] = -torch.inf # -inf = log(0) + + return rv diff --git a/diffusers/src/diffusers/pipelines/dit/__init__.py b/diffusers/src/diffusers/pipelines/dit/__init__.py new file mode 100755 index 0000000..fe2a94f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/dit/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_dit": ["DiTPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_dit import DiTPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/src/diffusers/pipelines/dit/pipeline_dit.py b/diffusers/src/diffusers/pipelines/dit/pipeline_dit.py new file mode 100755 index 0000000..14321b5 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/dit/pipeline_dit.py @@ -0,0 +1,236 @@ +# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) +# William Peebles and Saining Xie +# +# Copyright (c) 2021 OpenAI +# MIT License +# +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from ...models import AutoencoderKL, DiTTransformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DiTPipeline(DiffusionPipeline): + r""" + Pipeline for image generation based on a Transformer backbone instead of a UNet. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + transformer ([`DiTTransformer2DModel`]): + A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "transformer->vae" + + def __init__( + self, + transformer: DiTTransformer2DModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + id2label: Optional[Dict[int, str]] = None, + ): + super().__init__() + self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler) + + # create a imagenet -> id dictionary for easier use + self.labels = {} + if id2label is not None: + for key, value in id2label.items(): + for label in value.split(","): + self.labels[label.lstrip().rstrip()] = int(key) + self.labels = dict(sorted(self.labels.items())) + + def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: + r""" + + Map label strings from ImageNet to corresponding class ids. + + Parameters: + label (`str` or `dict` of `str`): + Label strings to be mapped to class ids. + + Returns: + `list` of `int`: + Class ids to be processed by pipeline. + """ + + if not isinstance(label, list): + label = list(label) + + for l in label: + if l not in self.labels: + raise ValueError( + f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}." + ) + + return [self.labels[l] for l in label] + + @torch.no_grad() + def __call__( + self, + class_labels: List[int], + guidance_scale: float = 4.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + The call function to the pipeline for generation. + + Args: + class_labels (List[int]): + List of ImageNet class labels for the images to be generated. + guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 250): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + ```py + >>> from diffusers import DiTPipeline, DPMSolverMultistepScheduler + >>> import torch + + >>> pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe = pipe.to("cuda") + + >>> # pick words from Imagenet class labels + >>> pipe.labels # to print all available words + + >>> # pick words that exist in ImageNet + >>> words = ["white shark", "umbrella"] + + >>> class_ids = pipe.get_label_ids(words) + + >>> generator = torch.manual_seed(33) + >>> output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator) + + >>> image = output.images[0] # label 'white shark' + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + + batch_size = len(class_labels) + latent_size = self.transformer.config.sample_size + latent_channels = self.transformer.config.in_channels + + latents = randn_tensor( + shape=(batch_size, latent_channels, latent_size, latent_size), + generator=generator, + device=self._execution_device, + dtype=self.transformer.dtype, + ) + latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents + + class_labels = torch.tensor(class_labels, device=self._execution_device).reshape(-1) + class_null = torch.tensor([1000] * batch_size, device=self._execution_device) + class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale > 1: + half = latent_model_input[: len(latent_model_input) // 2] + latent_model_input = torch.cat([half, half], dim=0) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + timesteps = t + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(latent_model_input.shape[0]) + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, timestep=timesteps, class_labels=class_labels_input + ).sample + + # perform guidance + if guidance_scale > 1: + eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + + half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + + noise_pred = torch.cat([eps, rest], dim=1) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + model_output, _ = torch.split(noise_pred, latent_channels, dim=1) + else: + model_output = noise_pred + + # compute previous image: x_t -> x_t-1 + latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample + + if guidance_scale > 1: + latents, _ = latent_model_input.chunk(2, dim=0) + else: + latents = latent_model_input + + latents = 1 / self.vae.config.scaling_factor * latents + samples = self.vae.decode(latents).sample + + samples = (samples / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + samples = self.numpy_to_pil(samples) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (samples,) + + return ImagePipelineOutput(images=samples) diff --git a/diffusers/src/diffusers/pipelines/flux/__init__.py b/diffusers/src/diffusers/pipelines/flux/__init__.py new file mode 100755 index 0000000..a19019a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/flux/__init__.py @@ -0,0 +1,59 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["FluxPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_flux"] = ["FluxPipeline"] + _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] + _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] + _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] + _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"] + _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] + _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_flux import FluxPipeline + from .pipeline_flux_controlnet import FluxControlNetPipeline + from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline + from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline + from .pipeline_flux_fill import FluxFillPipeline + from .pipeline_flux_img2img import FluxImg2ImgPipeline + from .pipeline_flux_inpaint import FluxInpaintPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/flux/modeling_flux.py b/diffusers/src/diffusers/pipelines/flux/modeling_flux.py new file mode 100644 index 0000000..5ff60f7 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/flux/modeling_flux.py @@ -0,0 +1,47 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from ...utils import BaseOutput + + +@dataclass +class ReduxImageEncoderOutput(BaseOutput): + image_embeds: Optional[torch.Tensor] = None + + +class ReduxImageEncoder(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + redux_dim: int = 1152, + txt_in_features: int = 4096, + ) -> None: + super().__init__() + + self.redux_up = nn.Linear(redux_dim, txt_in_features * 3) + self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features) + + def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput: + projected_x = self.redux_down(nn.functional.silu(self.redux_up(x))) + + return ReduxImageEncoderOutput(image_embeds=projected_x) diff --git a/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py b/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py new file mode 100755 index 0000000..bb21488 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py @@ -0,0 +1,771 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPipeline + + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py new file mode 100755 index 0000000..11b71b1 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -0,0 +1,924 @@ +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.utils import load_image + >>> from diffusers import FluxControlNetPipeline + >>> from diffusers import FluxControlNetModel + + >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny" + >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) + >>> pipe = FluxControlNetPipeline.from_pretrained( + ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") + >>> prompt = "A girl in city, 25 years old, cool, futuristic" + >>> image = pipe( + ... prompt, + ... control_image=control_image, + ... controlnet_conditioning_scale=0.6, + ... num_inference_steps=28, + ... guidance_scale=3.5, + ... ).images[0] + >>> image.save("flux.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + controlnet: Union[ + FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel + ], + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + control_image: PipelineImageInput = None, + control_mode: Optional[Union[int, List[int]]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_mode (`int` or `List[int]`,, *optional*, defaults to None): + The control mode when applying ControlNet-Union. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 3. Prepare control image + num_channels_latents = self.transformer.config.in_channels // 4 + if isinstance(self.controlnet, FluxControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + # vae encode + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + # set control mode + if control_mode is not None: + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image_.shape[-2:] + + # vae encode + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + control_images.append(control_image_) + + control_image = control_images + + # set control mode + control_mode_ = [] + if isinstance(control_mode, list): + for cmode in control_mode: + if cmode is None: + control_mode_.append(-1) + else: + control_mode_.append(cmode) + control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + guidance = ( + torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + + # controlnet + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=controlnet_conditioning_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) + + guidance = ( + torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py new file mode 100755 index 0000000..deeb9e3 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -0,0 +1,949 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetModel + >>> from diffusers.utils import load_image + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> controlnet = FluxControlNetModel.from_pretrained( + ... "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 + ... ) + + >>> pipe = FluxControlNetImg2ImgPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> pipe.text_encoder.to(torch.float16) + >>> pipe.controlnet.to(torch.float16) + >>> pipe.to("cuda") + + >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") + >>> init_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + ... ) + + >>> prompt = "A girl in city, 25 years old, cool, futuristic" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... control_image=control_image, + ... controlnet_conditioning_scale=0.6, + ... strength=0.7, + ... num_inference_steps=2, + ... guidance_scale=3.5, + ... ).images[0] + >>> image.save("flux_controlnet_img2img.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux controlnet pipeline for image-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + controlnet: Union[ + FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel + ], + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + pooled_prompt_embeds=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + control_mode: Optional[Union[int, List[int]]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image(s) to modify with the pipeline. + control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The ControlNet input condition. Image to control the generation. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.6): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + control_mode (`int` or `List[int]`, *optional*): + The mode for the ControlNet. If multiple ControlNets are used, this should be a list. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original transformer. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to + make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + Additional keyword arguments to be passed to the joint attention mechanism. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising step during the inference. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum length of the sequence to be generated. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + num_channels_latents = self.transformer.config.in_channels // 4 + + if isinstance(self.controlnet, FluxControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + if control_mode is not None: + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image_.shape[-2:] + + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + control_images.append(control_image_) + + control_image = control_images + + control_mode_ = [] + if isinstance(control_mode, list): + for cmode in control_mode: + if cmode is None: + control_mode_.append(-1) + else: + control_mode_.append(cmode) + control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + guidance = ( + torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=controlnet_conditioning_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) + + guidance = ( + torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py new file mode 100755 index 0000000..e763200 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -0,0 +1,1139 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxControlNetInpaintPipeline + >>> from diffusers.models import FluxControlNetModel + >>> from diffusers.utils import load_image + + >>> controlnet = FluxControlNetModel.from_pretrained( + ... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = FluxControlNetInpaintPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> control_image = load_image( + ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" + ... ) + >>> init_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + ... ) + >>> mask_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + ... ) + + >>> prompt = "A girl holding a sign that says InstantX" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... controlnet_conditioning_scale=0.7, + ... strength=0.7, + ... num_inference_steps=28, + ... guidance_scale=3.5, + ... ).images[0] + >>> image.save("flux_controlnet_inpaint.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux controlnet pipeline for inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + controlnet: Union[ + FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel + ], + ): + super().__init__() + + self.register_modules( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + controlnet=controlnet, + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + padding_mask_crop: Optional[int] = None, + timesteps: List[int] = None, + num_inference_steps: int = 28, + guidance_scale: float = 7.0, + control_mode: Optional[Union[int, List[int]]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image(s) to inpaint. + mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels + will be preserved. + masked_image_latents (`torch.FloatTensor`, *optional*): + Pre-generated masked image latents. + control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The ControlNet input condition. Image to control the generation. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.6): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. + padding_mask_crop (`int`, *optional*): + The size of the padding to use when cropping the mask. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + control_mode (`int` or `List[int]`, *optional*): + The mode for the ControlNet. If multiple ControlNets are used, this should be a list. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original transformer. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to + make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + Additional keyword arguments to be passed to the joint attention mechanism. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising step during the inference. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum length of the sequence to be generated. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + global_height = height + global_width = width + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + # 3. Encode input prompt + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region( + mask_image, global_width, global_height, pad=padding_mask_crop + ) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 5. Prepare control image + num_channels_latents = self.transformer.config.in_channels // 4 + if isinstance(self.controlnet, FluxControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=height, + height=width, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + # vae encode + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + # set control mode + if control_mode is not None: + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image_.shape[-2:] + + # vae encode + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + control_images.append(control_image_) + + control_image = control_images + + # set control mode + control_mode_ = [] + if isinstance(control_mode, list): + for cmode in control_mode: + if cmode is None: + control_mode_.append(-1) + else: + control_mode_.append(cmode) + control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + # 6. Prepare timesteps + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 7. Prepare latent variables + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Prepare mask latents + mask_condition = self.mask_processor.preprocess( + mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # predict the noise residual + if self.controlnet.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=controlnet_conditioning_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) + + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # For inpainting, we need to apply the mask and add the masked image latents + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # Post-processing + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_fill.py new file mode 100644 index 0000000..a0c1c47 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -0,0 +1,969 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxFillPipeline + >>> from diffusers.utils import load_image + + >>> image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png") + >>> mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png") + + >>> pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() # save some VRAM by offloading the model to CPU + + >>> image = pipe( + ... prompt="a white paper cup", + ... image=image, + ... mask_image=mask, + ... height=1632, + ... width=1232, + ... guidance_scale=30, + ... num_inference_steps=50, + ... max_sequence_length=512, + ... generator=torch.Generator("cpu").manual_seed(0), + ... ).images[0] + >>> image.save("flux_fill.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class FluxFillPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux Fill pipeline for image inpainting/outpainting. + + Reference: https://blackforestlabs.ai/flux-1-tools/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # 1. calculate the height and width of the latents + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + # 2. encode the masked image + if masked_image.shape[1] == num_channels_latents: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + # 3. duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + batch_size = batch_size * num_images_per_prompt + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # 4. pack the masked_image_latents + # batch_size, num_channels_latents, height, width -> batch_size, height//2 * width//2 , num_channels_latents*4 + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + + # 5.resize mask to latents shape we we concatenate the mask to the latents + mask = mask[:, 0, :, :] # batch_size, 8 * height, 8 * width (mask has not been 8x compressed) + mask = mask.view( + batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor + ) # batch_size, height, 8, width, 8 + mask = mask.permute(0, 2, 4, 1, 3) # batch_size, 8, 8, height, width + mask = mask.reshape( + batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width + ) # batch_size, 8*8, height, width + + # 6. pack the mask: + # batch_size, 64, height, width -> batch_size, height//2 * width//2 , 64*2*2 + mask = self._pack_latents( + mask, + batch_size, + self.vae_scale_factor * self.vae_scale_factor, + height, + width, + ) + mask = mask.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + image=None, + mask_image=None, + masked_image_latents=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + if image is not None and masked_image_latents is not None: + raise ValueError( + "Please provide either `image` or `masked_image_latents`, `masked_image_latents` should not be passed." + ) + + if image is not None and mask_image is None: + raise ValueError("Please provide `mask_image` when passing `image`.") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: Optional[torch.FloatTensor] = None, + mask_image: Optional[torch.FloatTensor] = None, + masked_image_latents: Optional[torch.FloatTensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 30.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will ge generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + image=image, + mask_image=mask_image, + masked_image_latents=masked_image_latents, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare prompt embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare mask and masked image latents + if masked_image_latents is not None: + masked_image_latents = masked_image_latents.to(latents.device) + else: + image = self.image_processor.preprocess(image, height=height, width=width) + mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width) + + masked_image = image * (1 - mask_image) + masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype) + + height, width = image.shape[-2:] + mask, masked_image_latents = self.prepare_mask_latents( + mask_image, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=torch.cat((latents, masked_image_latents), dim=2), + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-process the image + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) \ No newline at end of file diff --git a/diffusers/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_img2img.py new file mode 100755 index 0000000..bee4f6c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -0,0 +1,844 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import FluxImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> device = "cuda" + >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" + + >>> images = pipe( + ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py new file mode 100755 index 0000000..4603367 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -0,0 +1,1009 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0] + >>> image.save("flux_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will ge generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_transformer = self.transformer.config.in_channels + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/flux/pipeline_output.py b/diffusers/src/diffusers/pipelines/flux/pipeline_output.py new file mode 100755 index 0000000..c3a5a15 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/flux/pipeline_output.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image +import torch + +from ...utils import BaseOutput + + +@dataclass +class FluxPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +@dataclass +class FluxPriorReduxPipelineOutput(BaseOutput): + """ + Output class for Flux Prior Redux pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + prompt_embeds: torch.Tensor + pooled_prompt_embeds: torch.Tensor \ No newline at end of file diff --git a/diffusers/src/diffusers/pipelines/free_init_utils.py b/diffusers/src/diffusers/pipelines/free_init_utils.py new file mode 100755 index 0000000..1fb6759 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/free_init_utils.py @@ -0,0 +1,187 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple, Union + +import torch +import torch.fft as fft + +from ..utils.torch_utils import randn_tensor + + +class FreeInitMixin: + r"""Mixin class for FreeInit.""" + + def enable_free_init( + self, + num_iters: int = 3, + use_fast_sampling: bool = False, + method: str = "butterworth", + order: int = 4, + spatial_stop_frequency: float = 0.25, + temporal_stop_frequency: float = 0.25, + ): + """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. + + This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit). + + Args: + num_iters (`int`, *optional*, defaults to `3`): + Number of FreeInit noise re-initialization iterations. + use_fast_sampling (`bool`, *optional*, defaults to `False`): + Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables the + "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`. + method (`str`, *optional*, defaults to `butterworth`): + Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the FreeInit low + pass filter. + order (`int`, *optional*, defaults to `4`): + Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour + whereas lower values lead to `gaussian` method behaviour. + spatial_stop_frequency (`float`, *optional*, defaults to `0.25`): + Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in the + original implementation. + temporal_stop_frequency (`float`, *optional*, defaults to `0.25`): + Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in the + original implementation. + """ + self._free_init_num_iters = num_iters + self._free_init_use_fast_sampling = use_fast_sampling + self._free_init_method = method + self._free_init_order = order + self._free_init_spatial_stop_frequency = spatial_stop_frequency + self._free_init_temporal_stop_frequency = temporal_stop_frequency + + def disable_free_init(self): + """Disables the FreeInit mechanism if enabled.""" + self._free_init_num_iters = None + + @property + def free_init_enabled(self): + return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None + + def _get_free_init_freq_filter( + self, + shape: Tuple[int, ...], + device: Union[str, torch.dtype], + filter_type: str, + order: float, + spatial_stop_frequency: float, + temporal_stop_frequency: float, + ) -> torch.Tensor: + r"""Returns the FreeInit filter based on filter type and other input conditions.""" + + time, height, width = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + + if spatial_stop_frequency == 0 or temporal_stop_frequency == 0: + return mask + + if filter_type == "butterworth": + + def retrieve_mask(x): + return 1 / (1 + (x / spatial_stop_frequency**2) ** order) + elif filter_type == "gaussian": + + def retrieve_mask(x): + return math.exp(-1 / (2 * spatial_stop_frequency**2) * x) + elif filter_type == "ideal": + + def retrieve_mask(x): + return 1 if x <= spatial_stop_frequency * 2 else 0 + else: + raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal") + + for t in range(time): + for h in range(height): + for w in range(width): + d_square = ( + ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2 + + (2 * h / height - 1) ** 2 + + (2 * w / width - 1) ** 2 + ) + mask[..., t, h, w] = retrieve_mask(d_square) + + return mask.to(device) + + def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor: + r"""Noise reinitialization.""" + # FFT + x_freq = fft.fftn(x, dim=(-3, -2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) + noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) + noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) + + # frequency mix + high_pass_filter = 1 - low_pass_filter + x_freq_low = x_freq * low_pass_filter + noise_freq_high = noise_freq * high_pass_filter + x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain + + # IFFT + x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) + x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real + + return x_mixed + + def _apply_free_init( + self, + latents: torch.Tensor, + free_init_iteration: int, + num_inference_steps: int, + device: torch.device, + dtype: torch.dtype, + generator: torch.Generator, + ): + if free_init_iteration == 0: + self._free_init_initial_noise = latents.detach().clone() + else: + latent_shape = latents.shape + + free_init_filter_shape = (1, *latent_shape[1:]) + free_init_freq_filter = self._get_free_init_freq_filter( + shape=free_init_filter_shape, + device=device, + filter_type=self._free_init_method, + order=self._free_init_order, + spatial_stop_frequency=self._free_init_spatial_stop_frequency, + temporal_stop_frequency=self._free_init_temporal_stop_frequency, + ) + + current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1 + diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long() + + z_t = self.scheduler.add_noise( + original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device) + ).to(dtype=torch.float32) + + z_rand = randn_tensor( + shape=latent_shape, + generator=generator, + device=device, + dtype=torch.float32, + ) + latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter) + latents = latents.to(dtype) + + # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) + if self._free_init_use_fast_sampling: + num_inference_steps = max( + 1, int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1)) + ) + + if num_inference_steps > 0: + self.scheduler.set_timesteps(num_inference_steps, device=device) + + return latents, self.scheduler.timesteps diff --git a/diffusers/src/diffusers/pipelines/free_noise_utils.py b/diffusers/src/diffusers/pipelines/free_noise_utils.py new file mode 100755 index 0000000..dc0071a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/free_noise_utils.py @@ -0,0 +1,596 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock +from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..models.transformers.transformer_2d import Transformer2DModel +from ..models.unets.unet_motion_model import ( + AnimateDiffTransformer3D, + CrossAttnDownBlockMotion, + DownBlockMotion, + UpBlockMotion, +) +from ..pipelines.pipeline_utils import DiffusionPipeline +from ..utils import logging +from ..utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SplitInferenceModule(nn.Module): + r""" + A wrapper module class that splits inputs along a specified dimension before performing a forward pass. + + This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking + them into smaller chunks, processing each chunk separately, and then reassembling the results. + + Args: + module (`nn.Module`): + The underlying PyTorch module that will be applied to each chunk of split inputs. + split_size (`int`, defaults to `1`): + The size of each chunk after splitting the input tensor. + split_dim (`int`, defaults to `0`): + The dimension along which the input tensors are split. + input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`): + A list of keyword arguments (strings) that represent the input tensors to be split. + + Workflow: + 1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using + `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`. + 2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments + that were passed. + 3. The output tensors from each split are concatenated back together along `split_dim` before returning. + + Example: + ```python + >>> import torch + >>> import torch.nn as nn + + >>> model = nn.Linear(1000, 1000) + >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"]) + + >>> input_tensor = torch.randn(42, 1000) + >>> # Will split the tensor into 21 slices of shape [2, 1000]. + >>> output = split_module(input=input_tensor) + ``` + + It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex + multi-dimensional splitting. + """ + + def __init__( + self, + module: nn.Module, + split_size: int = 1, + split_dim: int = 0, + input_kwargs_to_split: List[str] = ["hidden_states"], + ) -> None: + super().__init__() + + self.module = module + self.split_size = split_size + self.split_dim = split_dim + self.input_kwargs_to_split = set(input_kwargs_to_split) + + def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + r"""Forward method for the `SplitInferenceModule`. + + This method processes the input by splitting specified keyword arguments along a given dimension, running the + underlying module on each split, and then concatenating the results. The splitting is controlled by the + `split_size` and `split_dim` parameters specified during initialization. + + Args: + *args (`Any`): + Positional arguments that are passed directly to the `module` without modification. + **kwargs (`Dict[str, torch.Tensor]`): + Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the + entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword + arguments are passed unchanged. + + Returns: + `Union[torch.Tensor, Tuple[torch.Tensor]]`: + The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred + without it. + - If the underlying module returns a single tensor, the result will be a single concatenated tensor + along the same `split_dim` after processing all splits. + - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated + along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors. + """ + split_inputs = {} + + # 1. Split inputs that were specified during initialization and also present in passed kwargs + for key in list(kwargs.keys()): + if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]): + continue + split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim) + kwargs.pop(key) + + # 2. Invoke forward pass across each split + results = [] + for split_input in zip(*split_inputs.values()): + inputs = dict(zip(split_inputs.keys(), split_input)) + inputs.update(kwargs) + + intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) + results.append(intermediate_tensor_or_tensor_tuple) + + # 3. Concatenate split restuls to obtain final outputs + if isinstance(results[0], torch.Tensor): + return torch.cat(results, dim=self.split_dim) + elif isinstance(results[0], tuple): + return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)]) + else: + raise ValueError( + "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's." + ) + + +class AnimateDiffFreeNoiseMixin: + r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).""" + + def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): + r"""Helper function to enable FreeNoise in transformer blocks.""" + + for motion_module in block.motion_modules: + num_transformer_blocks = len(motion_module.transformer_blocks) + + for i in range(num_transformer_blocks): + if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock): + motion_module.transformer_blocks[i].set_free_noise_properties( + self._free_noise_context_length, + self._free_noise_context_stride, + self._free_noise_weighting_scheme, + ) + else: + assert isinstance(motion_module.transformer_blocks[i], BasicTransformerBlock) + basic_transfomer_block = motion_module.transformer_blocks[i] + + motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock( + dim=basic_transfomer_block.dim, + num_attention_heads=basic_transfomer_block.num_attention_heads, + attention_head_dim=basic_transfomer_block.attention_head_dim, + dropout=basic_transfomer_block.dropout, + cross_attention_dim=basic_transfomer_block.cross_attention_dim, + activation_fn=basic_transfomer_block.activation_fn, + attention_bias=basic_transfomer_block.attention_bias, + only_cross_attention=basic_transfomer_block.only_cross_attention, + double_self_attention=basic_transfomer_block.double_self_attention, + positional_embeddings=basic_transfomer_block.positional_embeddings, + num_positional_embeddings=basic_transfomer_block.num_positional_embeddings, + context_length=self._free_noise_context_length, + context_stride=self._free_noise_context_stride, + weighting_scheme=self._free_noise_weighting_scheme, + ).to(device=self.device, dtype=self.dtype) + + motion_module.transformer_blocks[i].load_state_dict( + basic_transfomer_block.state_dict(), strict=True + ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim + ) + + def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): + r"""Helper function to disable FreeNoise in transformer blocks.""" + + for motion_module in block.motion_modules: + num_transformer_blocks = len(motion_module.transformer_blocks) + + for i in range(num_transformer_blocks): + if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock): + free_noise_transfomer_block = motion_module.transformer_blocks[i] + + motion_module.transformer_blocks[i] = BasicTransformerBlock( + dim=free_noise_transfomer_block.dim, + num_attention_heads=free_noise_transfomer_block.num_attention_heads, + attention_head_dim=free_noise_transfomer_block.attention_head_dim, + dropout=free_noise_transfomer_block.dropout, + cross_attention_dim=free_noise_transfomer_block.cross_attention_dim, + activation_fn=free_noise_transfomer_block.activation_fn, + attention_bias=free_noise_transfomer_block.attention_bias, + only_cross_attention=free_noise_transfomer_block.only_cross_attention, + double_self_attention=free_noise_transfomer_block.double_self_attention, + positional_embeddings=free_noise_transfomer_block.positional_embeddings, + num_positional_embeddings=free_noise_transfomer_block.num_positional_embeddings, + ).to(device=self.device, dtype=self.dtype) + + motion_module.transformer_blocks[i].load_state_dict( + free_noise_transfomer_block.state_dict(), strict=True + ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim + ) + + def _check_inputs_free_noise( + self, + prompt, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + num_frames, + ) -> None: + if not isinstance(prompt, (str, dict)): + raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}") + + if negative_prompt is not None: + if not isinstance(negative_prompt, (str, dict)): + raise ValueError( + f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}" + ) + + if prompt_embeds is not None or negative_prompt_embeds is not None: + raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.") + + frame_indices = [isinstance(x, int) for x in prompt.keys()] + frame_prompts = [isinstance(x, str) for x in prompt.values()] + min_frame = min(list(prompt.keys())) + max_frame = max(list(prompt.keys())) + + if not all(frame_indices): + raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.") + if not all(frame_prompts): + raise ValueError("Expected str values in `prompt` dict for FreeNoise.") + if min_frame != 0: + raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.") + if max_frame >= num_frames: + raise ValueError( + f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing." + ) + + def _encode_prompt_free_noise( + self, + prompt: Union[str, Dict[int, str]], + num_frames: int, + device: torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, Dict[int, str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ) -> torch.Tensor: + if negative_prompt is None: + negative_prompt = "" + + # Ensure that we have a dictionary of prompts + if isinstance(prompt, str): + prompt = {0: prompt} + if isinstance(negative_prompt, str): + negative_prompt = {0: negative_prompt} + + self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames) + + # Sort the prompts based on frame indices + prompt = dict(sorted(prompt.items())) + negative_prompt = dict(sorted(negative_prompt.items())) + + # Ensure that we have a prompt for the last frame index + prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]] + negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]] + + frame_indices = list(prompt.keys()) + frame_prompts = list(prompt.values()) + frame_negative_indices = list(negative_prompt.keys()) + frame_negative_prompts = list(negative_prompt.values()) + + # Generate and interpolate positive prompts + prompt_embeds, _ = self.encode_prompt( + prompt=frame_prompts, + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=False, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + shape = (num_frames, *prompt_embeds.shape[1:]) + prompt_interpolation_embeds = prompt_embeds.new_zeros(shape) + + for i in range(len(frame_indices) - 1): + start_frame = frame_indices[i] + end_frame = frame_indices[i + 1] + start_tensor = prompt_embeds[i].unsqueeze(0) + end_tensor = prompt_embeds[i + 1].unsqueeze(0) + + prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback( + start_frame, end_frame, start_tensor, end_tensor + ) + + # Generate and interpolate negative prompts + negative_prompt_embeds = None + negative_prompt_interpolation_embeds = None + + if do_classifier_free_guidance: + _, negative_prompt_embeds = self.encode_prompt( + prompt=[""] * len(frame_negative_prompts), + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=True, + negative_prompt=frame_negative_prompts, + prompt_embeds=None, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape) + + for i in range(len(frame_negative_indices) - 1): + start_frame = frame_negative_indices[i] + end_frame = frame_negative_indices[i + 1] + start_tensor = negative_prompt_embeds[i].unsqueeze(0) + end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0) + + negative_prompt_interpolation_embeds[ + start_frame : end_frame + 1 + ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) + + prompt_embeds = prompt_interpolation_embeds + negative_prompt_embeds = negative_prompt_interpolation_embeds + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds, negative_prompt_embeds + + def _prepare_latents_free_noise( + self, + batch_size: int, + num_channels_latents: int, + num_frames: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + context_num_frames = ( + self._free_noise_context_length if self._free_noise_context_length == "repeat_context" else num_frames + ) + + shape = ( + batch_size, + num_channels_latents, + context_num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if self._free_noise_noise_type == "random": + return latents + else: + if latents.size(2) == num_frames: + return latents + elif latents.size(2) != self._free_noise_context_length: + raise ValueError( + f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {self._free_noise_context_length}, but found {latents.size(2)}" + ) + latents = latents.to(device) + + if self._free_noise_noise_type == "shuffle_context": + for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): + # ensure window is within bounds + window_start = max(0, i - self._free_noise_context_length) + window_end = min(num_frames, window_start + self._free_noise_context_stride) + window_length = window_end - window_start + + if window_length == 0: + break + + indices = torch.LongTensor(list(range(window_start, window_end))) + shuffled_indices = indices[torch.randperm(window_length, generator=generator)] + + current_start = i + current_end = min(num_frames, current_start + window_length) + if current_end == current_start + window_length: + # batch of frames perfectly fits the window + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + else: + # handle the case where the last batch of frames does not fit perfectly with the window + prefix_length = current_end - current_start + shuffled_indices = shuffled_indices[:prefix_length] + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + + elif self._free_noise_noise_type == "repeat_context": + num_repeats = (num_frames + self._free_noise_context_length - 1) // self._free_noise_context_length + latents = torch.cat([latents] * num_repeats, dim=2) + + latents = latents[:, :, :num_frames] + return latents + + def _lerp( + self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor + ) -> torch.Tensor: + num_indices = end_index - start_index + 1 + interpolated_tensors = [] + + for i in range(num_indices): + alpha = i / (num_indices - 1) + interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor + interpolated_tensors.append(interpolated_tensor) + + interpolated_tensors = torch.cat(interpolated_tensors) + return interpolated_tensors + + def enable_free_noise( + self, + context_length: Optional[int] = 16, + context_stride: int = 4, + weighting_scheme: str = "pyramid", + noise_type: str = "shuffle_context", + prompt_interpolation_callback: Optional[ + Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor] + ] = None, + ) -> None: + r""" + Enable long video generation using FreeNoise. + + Args: + context_length (`int`, defaults to `16`, *optional*): + The number of video frames to process at once. It's recommended to set this to the maximum frames the + Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion + adapter config is used. + context_stride (`int`, *optional*): + Long videos are generated by processing many frames. FreeNoise processes these frames in sliding + windows of size `context_length`. Context stride allows you to specify how many frames to skip between + each window. For example, a context length of 16 and context stride of 4 would process 24 frames as: + [0, 15], [4, 19], [8, 23] (0-based indexing) + weighting_scheme (`str`, defaults to `pyramid`): + Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting + schemes are supported currently: + - "flat" + Performs weighting averaging with a flat weight pattern: [1, 1, 1, 1, 1]. + - "pyramid" + Performs weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1]. + - "delayed_reverse_sawtooth" + Performs weighted averaging with low weights for earlier frames and high-to-low weights for + later frames: [0.01, 0.01, 3, 2, 1]. + noise_type (`str`, defaults to "shuffle_context"): + Must be one of ["shuffle_context", "repeat_context", "random"]. + - "shuffle_context" + Shuffles a fixed batch of `context_length` latents to create a final latent of size + `num_frames`. This is usually the best setting for most generation scenarious. However, there + might be visible repetition noticeable in the kinds of motion/animation generated. + - "repeated_context" + Repeats a fixed batch of `context_length` latents to create a final latent of size + `num_frames`. + - "random" + The final latents are random without any repetition. + """ + + allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"] + allowed_noise_type = ["shuffle_context", "repeat_context", "random"] + + if context_length > self.motion_adapter.config.motion_max_seq_length: + logger.warning( + f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results." + ) + if weighting_scheme not in allowed_weighting_scheme: + raise ValueError( + f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}" + ) + if noise_type not in allowed_noise_type: + raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}") + + self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length + self._free_noise_context_stride = context_stride + self._free_noise_weighting_scheme = weighting_scheme + self._free_noise_noise_type = noise_type + self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp + + if hasattr(self.unet.mid_block, "motion_modules"): + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + else: + blocks = [*self.unet.down_blocks, *self.unet.up_blocks] + + for block in blocks: + self._enable_free_noise_in_block(block) + + def disable_free_noise(self) -> None: + r"""Disable the FreeNoise sampling mechanism.""" + self._free_noise_context_length = None + + if hasattr(self.unet.mid_block, "motion_modules"): + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + else: + blocks = [*self.unet.down_blocks, *self.unet.up_blocks] + + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + for block in blocks: + self._disable_free_noise_in_block(block) + + def _enable_split_inference_motion_modules_( + self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int + ) -> None: + for motion_module in motion_modules: + motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"]) + + for i in range(len(motion_module.transformer_blocks)): + motion_module.transformer_blocks[i] = SplitInferenceModule( + motion_module.transformer_blocks[i], + spatial_split_size, + 0, + ["hidden_states", "encoder_hidden_states"], + ) + + motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"]) + + def _enable_split_inference_attentions_( + self, attentions: List[Transformer2DModel], temporal_split_size: int + ) -> None: + for i in range(len(attentions)): + attentions[i] = SplitInferenceModule( + attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"] + ) + + def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None: + for i in range(len(resnets)): + resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"]) + + def _enable_split_inference_samplers_( + self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int + ) -> None: + for i in range(len(samplers)): + samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"]) + + def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None: + r""" + Enable FreeNoise memory optimizations by utilizing + [`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks. + + Args: + spatial_split_size (`int`, defaults to `256`): + The split size across spatial dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion + modeling blocks. + temporal_split_size (`int`, defaults to `16`): + The split size across temporal dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial + attention, resnets, downsampling and upsampling blocks. + """ + # TODO(aryan): Discuss on what's the best way to provide more control to users + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + for block in blocks: + if getattr(block, "motion_modules", None) is not None: + self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size) + if getattr(block, "attentions", None) is not None: + self._enable_split_inference_attentions_(block.attentions, temporal_split_size) + if getattr(block, "resnets", None) is not None: + self._enable_split_inference_resnets_(block.resnets, temporal_split_size) + if getattr(block, "downsamplers", None) is not None: + self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size) + if getattr(block, "upsamplers", None) is not None: + self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size) + + @property + def free_noise_enabled(self): + return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None diff --git a/diffusers/src/diffusers/pipelines/hunyuandit/__init__.py b/diffusers/src/diffusers/pipelines/hunyuandit/__init__.py new file mode 100755 index 0000000..8337399 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/hunyuandit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuandit"] = ["HunyuanDiTPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuandit import HunyuanDiTPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/diffusers/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py new file mode 100755 index 0000000..86089ab --- /dev/null +++ b/diffusers/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -0,0 +1,900 @@ +# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import ( + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import HunyuanDiTPipeline + + >>> pipe = HunyuanDiTPipeline.from_pretrained( + ... "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> # You may also use English prompt as HunyuanDiT supports both English and Chinese + >>> # prompt = "An astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTPipeline(DiffusionPipeline): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`MT5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + text_encoder_2=T5EncoderModel, + tokenizer_2=MT5Tokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + def encode_prompt( + self, + prompt: str, + device: torch.device = None, + dtype: torch.dtype = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if device is None: + device = self._execution_device + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`Tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7 create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/i2vgen_xl/__init__.py b/diffusers/src/diffusers/pipelines/i2vgen_xl/__init__.py new file mode 100755 index 0000000..b24a7e4 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/i2vgen_xl/__init__.py @@ -0,0 +1,46 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_i2vgen_xl"] = ["I2VGenXLPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_i2vgen_xl import I2VGenXLPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/diffusers/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py new file mode 100755 index 0000000..f528b60 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py @@ -0,0 +1,783 @@ +# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKL +from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet +from ...schedulers import DDIMScheduler +from ...utils import ( + BaseOutput, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import I2VGenXLPipeline + >>> from diffusers.utils import export_to_gif, load_image + + >>> pipeline = I2VGenXLPipeline.from_pretrained( + ... "ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipeline.enable_model_cpu_offload() + + >>> image_url = ( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png" + ... ) + >>> image = load_image(image_url).convert("RGB") + + >>> prompt = "Papers were floating in the air on a table in the library" + >>> negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms" + >>> generator = torch.manual_seed(8888) + + >>> frames = pipeline( + ... prompt=prompt, + ... image=image, + ... num_inference_steps=50, + ... negative_prompt=negative_prompt, + ... guidance_scale=9.0, + ... generator=generator, + ... ).frames[0] + >>> video_path = export_to_gif(frames, "i2v.gif") + ``` +""" + + +@dataclass +class I2VGenXLPipelineOutput(BaseOutput): + r""" + Output class for image-to-video pipeline. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised + PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)` + """ + + frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]] + + +class I2VGenXLPipeline( + DiffusionPipeline, + StableDiffusionMixin, +): + r""" + Pipeline for image-to-video generation as proposed in [I2VGenXL](https://i2vgen-xl.github.io/). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`I2VGenXLUNet`]): + A [`I2VGenXLUNet`] to denoise the encoded video latents. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + image_encoder: CLIPVisionModelWithProjection, + feature_extractor: CLIPImageProcessor, + unet: I2VGenXLUNet, + scheduler: DDIMScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + # `do_resize=False` as we do custom resizing. + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def encode_prompt( + self, + prompt, + device, + num_videos_per_prompt, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + # Apply clip_skip to negative prompt embeds + if clip_skip is None: + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + else: + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + negative_prompt_embeds = negative_prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + negative_prompt_embeds = self.text_encoder.text_model.final_layer_norm(negative_prompt_embeds) + + if self.do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def _encode_image(self, image, device, num_videos_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.video_processor.pil_to_numpy(image) + image = self.video_processor.numpy_to_pt(image) + + # Normalize the image with CLIP training stats. + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if self.do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def decode_latents(self, latents, decode_chunk_size=None): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + if decode_chunk_size is not None: + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + frame = self.vae.decode(latents[i : i + decode_chunk_size]).sample + frames.append(frame) + image = torch.cat(frames, dim=0) + else: + image = self.vae.decode(latents).sample + + decode_shape = (batch_size, num_frames, -1) + image.shape[2:] + video = image[None, :].reshape(decode_shape).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + def prepare_image_latents( + self, + image, + device, + num_frames, + num_videos_per_prompt, + ): + image = image.to(device=device) + image_latents = self.vae.encode(image).latent_dist.sample() + image_latents = image_latents * self.vae.config.scaling_factor + + # Add frames dimension to image latents + image_latents = image_latents.unsqueeze(2) + + # Append a position mask for each subsequent frame + # after the intial image latent frame + frame_position_mask = [] + for frame_idx in range(num_frames - 1): + scale = (frame_idx + 1) / (num_frames - 1) + frame_position_mask.append(torch.ones_like(image_latents[:, :, :1]) * scale) + if frame_position_mask: + frame_position_mask = torch.cat(frame_position_mask, dim=2) + image_latents = torch.cat([image_latents, frame_position_mask], dim=2) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1, 1) + + if self.do_classifier_free_guidance: + image_latents = torch.cat([image_latents] * 2) + + return image_latents + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: Optional[int] = 704, + width: Optional[int] = 1280, + target_fps: Optional[int] = 16, + num_frames: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 9.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: float = 0.0, + num_videos_per_prompt: Optional[int] = 1, + decode_chunk_size: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = 1, + ): + r""" + The call function to the pipeline for image-to-video generation with [`I2VGenXLPipeline`]. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + target_fps (`int`, *optional*): + Frames per second. The rate at which the generated images shall be exported to a video after + generation. This is also used as a "micro-condition" while generation. + num_frames (`int`, *optional*): + The number of video frames to generate. + num_inference_steps (`int`, *optional*): + The number of denoising steps. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + num_videos_per_prompt (`int`, *optional*): + The number of images to generate per prompt. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. The higher the chunk size, the higher the temporal + consistency between frames, but also the higher the memory consumption. By default, the decoder will + decode all frames at once for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, image, height, width, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + self._guidance_scale = guidance_scale + + # 3.1 Encode input text prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 3.2 Encode image prompt + # 3.2.1 Image encodings. + # https://github.com/ali-vilab/i2vgen-xl/blob/2539c9262ff8a2a22fa9daecbfd13f0a2dbc32d0/tools/inferences/inference_i2vgen_entrance.py#L114 + cropped_image = _center_crop_wide(image, (width, width)) + cropped_image = _resize_bilinear( + cropped_image, (self.feature_extractor.crop_size["width"], self.feature_extractor.crop_size["height"]) + ) + image_embeddings = self._encode_image(cropped_image, device, num_videos_per_prompt) + + # 3.2.2 Image latents. + resized_image = _center_crop_wide(image, (width, height)) + image = self.video_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype) + image_latents = self.prepare_image_latents( + image, + device=device, + num_frames=num_frames, + num_videos_per_prompt=num_videos_per_prompt, + ) + + # 3.3 Prepare additional conditions for the UNet. + if self.do_classifier_free_guidance: + fps_tensor = torch.tensor([target_fps, target_fps]).to(device) + else: + fps_tensor = torch.tensor([target_fps]).to(device) + fps_tensor = fps_tensor.repeat(batch_size * num_videos_per_prompt, 1).ravel() + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + fps=fps_tensor, + image_latents=image_latents, + image_embeddings=image_embeddings, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # reshape latents + batch_size, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # reshape latents back + latents = latents[None, :].reshape(batch_size, frames, channel, width, height).permute(0, 2, 1, 3, 4) + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 8. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 9. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return I2VGenXLPipelineOutput(frames=video) + + +# The following utilities are taken and adapted from +# https://github.com/ali-vilab/i2vgen-xl/blob/main/utils/transforms.py. + + +def _convert_pt_to_pil(image: Union[torch.Tensor, List[torch.Tensor]]): + if isinstance(image, list) and isinstance(image[0], torch.Tensor): + image = torch.cat(image, 0) + + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.unsqueeze(0) + + image_numpy = VaeImageProcessor.pt_to_numpy(image) + image_pil = VaeImageProcessor.numpy_to_pil(image_numpy) + image = image_pil + + return image + + +def _resize_bilinear( + image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], resolution: Tuple[int, int] +): + # First convert the images to PIL in case they are float tensors (only relevant for tests now). + image = _convert_pt_to_pil(image) + + if isinstance(image, list): + image = [u.resize(resolution, PIL.Image.BILINEAR) for u in image] + else: + image = image.resize(resolution, PIL.Image.BILINEAR) + return image + + +def _center_crop_wide( + image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], resolution: Tuple[int, int] +): + # First convert the images to PIL in case they are float tensors (only relevant for tests now). + image = _convert_pt_to_pil(image) + + if isinstance(image, list): + scale = min(image[0].size[0] / resolution[0], image[0].size[1] / resolution[1]) + image = [u.resize((round(u.width // scale), round(u.height // scale)), resample=PIL.Image.BOX) for u in image] + + # center crop + x1 = (image[0].width - resolution[0]) // 2 + y1 = (image[0].height - resolution[1]) // 2 + image = [u.crop((x1, y1, x1 + resolution[0], y1 + resolution[1])) for u in image] + return image + else: + scale = min(image.size[0] / resolution[0], image.size[1] / resolution[1]) + image = image.resize((round(image.width // scale), round(image.height // scale)), resample=PIL.Image.BOX) + x1 = (image.width - resolution[0]) // 2 + y1 = (image.height - resolution[1]) // 2 + image = image.crop((x1, y1, x1 + resolution[0], y1 + resolution[1])) + return image diff --git a/diffusers/src/diffusers/pipelines/kandinsky/__init__.py b/diffusers/src/diffusers/pipelines/kandinsky/__init__.py new file mode 100755 index 0000000..606f7b3 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky/__init__.py @@ -0,0 +1,66 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_kandinsky"] = ["KandinskyPipeline"] + _import_structure["pipeline_kandinsky_combined"] = [ + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyInpaintCombinedPipeline", + ] + _import_structure["pipeline_kandinsky_img2img"] = ["KandinskyImg2ImgPipeline"] + _import_structure["pipeline_kandinsky_inpaint"] = ["KandinskyInpaintPipeline"] + _import_structure["pipeline_kandinsky_prior"] = ["KandinskyPriorPipeline", "KandinskyPriorPipelineOutput"] + _import_structure["text_encoder"] = ["MultilingualCLIP"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .pipeline_kandinsky import KandinskyPipeline + from .pipeline_kandinsky_combined import ( + KandinskyCombinedPipeline, + KandinskyImg2ImgCombinedPipeline, + KandinskyInpaintCombinedPipeline, + ) + from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline + from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline + from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput + from .text_encoder import MultilingualCLIP + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py new file mode 100755 index 0000000..b2041e1 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py @@ -0,0 +1,407 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Union + +import torch +from transformers import ( + XLMRobertaTokenizer, +) + +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDIMScheduler, DDPMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .text_encoder import MultilingualCLIP + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/Kandinsky-2-1-prior") + >>> pipe_prior.to("cuda") + + >>> prompt = "red cat, 4k photo" + >>> out = pipe_prior(prompt) + >>> image_emb = out.image_embeds + >>> negative_image_emb = out.negative_image_embeds + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") + >>> pipe.to("cuda") + + >>> image = pipe( + ... prompt, + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... ).images + + >>> image[0].save("cat.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +class KandinskyPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "text_encoder->unet->movq" + + def __init__( + self, + text_encoder: MultilingualCLIP, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, DDPMScheduler], + movq: VQModel, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=77, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image_embeds: Union[torch.Tensor, List[torch.Tensor]], + negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.unet.config.in_channels + + height, width = get_new_h_w(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py new file mode 100755 index 0000000..fe99097 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py @@ -0,0 +1,817 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, List, Optional, Union + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, + XLMRobertaTokenizer, +) + +from ...models import PriorTransformer, UNet2DConditionModel, VQModel +from ...schedulers import DDIMScheduler, DDPMScheduler, UnCLIPScheduler +from ...utils import ( + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from .pipeline_kandinsky import KandinskyPipeline +from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline +from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline +from .pipeline_kandinsky_prior import KandinskyPriorPipeline +from .text_encoder import MultilingualCLIP + + +TEXT2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipe = AutoPipelineForText2Image.from_pretrained( + "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k" + + image = pipe(prompt=prompt, num_inference_steps=25).images[0] + ``` +""" + +IMAGE2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForImage2Image + import torch + import requests + from io import BytesIO + from PIL import Image + import os + + pipe = AutoPipelineForImage2Image.from_pretrained( + "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A fantasy landscape, Cinematic lighting" + negative_prompt = "low quality, bad quality" + + url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + response = requests.get(url) + image = Image.open(BytesIO(response.content)).convert("RGB") + image.thumbnail((768, 768)) + + image = pipe(prompt=prompt, image=original_image, num_inference_steps=25).images[0] + ``` +""" + +INPAINT_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForInpainting + from diffusers.utils import load_image + import torch + import numpy as np + + pipe = AutoPipelineForInpainting.from_pretrained( + "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A fantasy landscape, Cinematic lighting" + negative_prompt = "low quality, bad quality" + + original_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + ) + + mask = np.zeros((768, 768), dtype=np.float32) + # Let's mask out an area above the cat's head + mask[:250, 250:-250] = 1 + + image = pipe(prompt=prompt, image=original_image, mask_image=mask, num_inference_steps=25).images[0] + ``` +""" + + +class KandinskyCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + _load_connected_pipes = True + model_cpu_offload_seq = "text_encoder->unet->movq->prior_prior->prior_image_encoder->prior_text_encoder" + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + text_encoder: MultilingualCLIP, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, DDPMScheduler], + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyPriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 + Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a + GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. + Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + outputs = self.decoder_pipe( + prompt=prompt, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + width=width, + height=height, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + callback=callback, + callback_steps=callback_steps, + return_dict=return_dict, + ) + + self.maybe_free_model_hooks() + + return outputs + + +class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + _load_connected_pipes = True + model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq" + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + text_encoder: MultilingualCLIP, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, DDPMScheduler], + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyPriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyImg2ImgPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(IMAGE2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + strength: float = 0.3, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + image = [image] if isinstance(prompt, PIL.Image.Image) else image + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + if ( + isinstance(image, (list, tuple)) + and len(image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(image) == 0 + ): + image = (image_embeds.shape[0] // len(image)) * image + + outputs = self.decoder_pipe( + prompt=prompt, + image=image, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + strength=strength, + width=width, + height=height, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + callback=callback, + callback_steps=callback_steps, + return_dict=return_dict, + ) + + self.maybe_free_model_hooks() + + return outputs + + +class KandinskyInpaintCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + _load_connected_pipes = True + model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq" + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + text_encoder: MultilingualCLIP, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, DDPMScheduler], + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyPriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyInpaintPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(INPAINT_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], + mask_image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + mask_image (`np.array`): + Tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while + black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single + channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, + so the expected shape would be `(B, H, W, 1)`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + image = [image] if isinstance(prompt, PIL.Image.Image) else image + mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + if ( + isinstance(image, (list, tuple)) + and len(image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(image) == 0 + ): + image = (image_embeds.shape[0] // len(image)) * image + + if ( + isinstance(mask_image, (list, tuple)) + and len(mask_image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(mask_image) == 0 + ): + mask_image = (image_embeds.shape[0] // len(mask_image)) * mask_image + + outputs = self.decoder_pipe( + prompt=prompt, + image=image, + mask_image=mask_image, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + width=width, + height=height, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + callback=callback, + callback_steps=callback_steps, + return_dict=return_dict, + ) + + self.maybe_free_model_hooks() + + return outputs diff --git a/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py new file mode 100755 index 0000000..ef5241f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -0,0 +1,500 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from PIL import Image +from transformers import ( + XLMRobertaTokenizer, +) + +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDIMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .text_encoder import MultilingualCLIP + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "A red cartoon frog, 4k" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyImg2ImgPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/frog.png" + ... ) + + >>> image = pipe( + ... prompt, + ... image=init_image, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... strength=0.2, + ... ).images + + >>> image[0].save("red_frog.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +def prepare_image(pil_image, w=512, h=512): + pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + arr = np.array(pil_image.convert("RGB")) + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2, 0, 1]) + image = torch.from_numpy(arr).unsqueeze(0) + return image + + +class KandinskyImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ image encoder and decoder + """ + + model_cpu_offload_seq = "text_encoder->unet->movq" + + def __init__( + self, + text_encoder: MultilingualCLIP, + movq: VQModel, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, latents, latent_timestep, shape, dtype, device, generator, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + + shape = latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + latents = self.add_noise(latents, noise, latent_timestep) + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + # add_noise method to overwrite the one in schedule because it use a different beta schedule for adding noise vs sampling + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + + return noisy_samples + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], + image_embeds: torch.Tensor, + negative_image_embeds: torch.Tensor, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + strength: float = 0.3, + guidance_scale: float = 7.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + # 1. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + + # 2. get text and image embeddings + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + # 3. pre-processing initial image + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) + image = image.to(dtype=prompt_embeds.dtype, device=device) + + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + # the formular to calculate timestep for add_noise is taken from the original kandinsky repo + latent_timestep = int(self.scheduler.config.num_train_timesteps * strength) - 2 + + latent_timestep = torch.tensor([latent_timestep] * batch_size, dtype=timesteps_tensor.dtype, device=device) + + num_channels_latents = self.unet.config.in_channels + + height, width = get_new_h_w(height, width, self.movq_scale_factor) + + # 5. Create initial latent + latents = self.prepare_latents( + latents, + latent_timestep, + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + self.scheduler, + ) + + # 6. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 7. post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py new file mode 100755 index 0000000..778b6e3 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py @@ -0,0 +1,635 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from packaging import version +from PIL import Image +from transformers import ( + XLMRobertaTokenizer, +) + +from ... import __version__ +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDIMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .text_encoder import MultilingualCLIP + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + >>> import numpy as np + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "a hat" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyInpaintPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> mask = np.zeros((768, 768), dtype=np.float32) + >>> mask[:250, 250:-250] = 1 + + >>> out = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ) + + >>> image = out.images[0] + >>> image.save("cat_with_hat.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +def prepare_mask(masks): + prepared_masks = [] + for mask in masks: + old_mask = deepcopy(mask) + for i in range(mask.shape[1]): + for j in range(mask.shape[2]): + if old_mask[0][i][j] == 1: + continue + if i != 0: + mask[:, i - 1, j] = 0 + if j != 0: + mask[:, i, j - 1] = 0 + if i != 0 and j != 0: + mask[:, i - 1, j - 1] = 0 + if i != mask.shape[1] - 1: + mask[:, i + 1, j] = 0 + if j != mask.shape[2] - 1: + mask[:, i, j + 1] = 0 + if i != mask.shape[1] - 1 and j != mask.shape[2] - 1: + mask[:, i + 1, j + 1] = 0 + prepared_masks.append(mask) + return torch.stack(prepared_masks, dim=0) + + +def prepare_mask_and_masked_image(image, mask, height, width): + r""" + Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will + be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for + the ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + mask = 1 - mask + + return mask, image + + +class KandinskyInpaintPipeline(DiffusionPipeline): + """ + Pipeline for text-guided image inpainting using Kandinsky2.1 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ image encoder and decoder + """ + + model_cpu_offload_seq = "text_encoder->unet->movq" + + def __init__( + self, + text_encoder: MultilingualCLIP, + movq: VQModel, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + movq=movq, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + self._warn_has_been_called = False + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.Tensor, PIL.Image.Image], + mask_image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], + image_embeds: torch.Tensor, + negative_image_embeds: torch.Tensor, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image` or `np.ndarray`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`PIL.Image.Image`,`torch.Tensor` or `np.ndarray`): + `Image`, or a tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. You can pass a pytorch tensor as mask only if the + image you passed is a pytorch tensor, and it should contain one color channel (L) instead of 3, so the + expected shape would be either `(B, 1, H, W,)`, `(B, H, W)`, `(1, H, W)` or `(H, W)` If image is an PIL + image or numpy array, mask should also be a either PIL image or numpy array. If it is a PIL image, it + will be converted to a single channel (luminance) before use. If it is a nummpy array, the expected + shape is `(H, W)`. + image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + if not self._warn_has_been_called and version.parse(version.parse(__version__).base_version) < version.parse( + "0.23.0.dev0" + ): + logger.warning( + "Please note that the expected format of `mask_image` has recently been changed. " + "Before diffusers == 0.19.0, Kandinsky Inpainting pipelines repainted black pixels and preserved black pixels. " + "As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. " + "This way, Kandinsky's masking behavior is aligned with Stable Diffusion. " + "THIS means that you HAVE to invert the input mask to have the same behavior as before as explained in https://github.com/huggingface/diffusers/pull/4207. " + "This warning will be surpressed after the first inference call and will be removed in diffusers>0.23.0" + ) + self._warn_has_been_called = True + + # Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + # preprocess image and mask + mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) + + image = image.to(dtype=prompt_embeds.dtype, device=device) + image = self.movq.encode(image)["latents"] + + mask_image = mask_image.to(dtype=prompt_embeds.dtype, device=device) + + image_shape = tuple(image.shape[-2:]) + mask_image = F.interpolate( + mask_image, + image_shape, + mode="nearest", + ) + mask_image = prepare_mask(mask_image) + masked_image = image * mask_image + + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + mask_image = mask_image.repeat(2, 1, 1, 1) + masked_image = masked_image.repeat(2, 1, 1, 1) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.movq.config.latent_channels + + # get h, w for latents + sample_height, sample_width = get_new_h_w(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, sample_height, sample_width), + text_encoder_hidden_states.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # Check that sizes of mask, masked image and latents match with expected + num_channels_mask = mask_image.shape[1] + num_channels_masked_image = masked_image.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py new file mode 100755 index 0000000..b5152d7 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py @@ -0,0 +1,547 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer +from ...schedulers import UnCLIPScheduler +from ...utils import ( + BaseOutput, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior") + >>> pipe_prior.to("cuda") + + >>> prompt = "red cat, 4k photo" + >>> out = pipe_prior(prompt) + >>> image_emb = out.image_embeds + >>> negative_image_emb = out.negative_image_embeds + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") + >>> pipe.to("cuda") + + >>> image = pipe( + ... prompt, + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... ).images + + >>> image[0].save("cat.png") + ``` +""" + +EXAMPLE_INTERPOLATE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline + >>> from diffusers.utils import load_image + >>> import PIL + + >>> import torch + >>> from torchvision import transforms + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> img1 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> img2 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/starry_night.jpeg" + ... ) + + >>> images_texts = ["a cat", img1, img2] + >>> weights = [0.3, 0.3, 0.4] + >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) + >>> pipe.to("cuda") + + >>> image = pipe( + ... "", + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=150, + ... ).images[0] + + >>> image.save("starry_cat.png") + ``` +""" + + +@dataclass +class KandinskyPriorPipelineOutput(BaseOutput): + """ + Output class for KandinskyPriorPipeline. + + Args: + image_embeds (`torch.Tensor`) + clip image embeddings for text prompt + negative_image_embeds (`List[PIL.Image.Image]` or `np.ndarray`) + clip image embeddings for unconditional tokens + """ + + image_embeds: Union[torch.Tensor, np.ndarray] + negative_image_embeds: Union[torch.Tensor, np.ndarray] + + +class KandinskyPriorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + _exclude_from_cpu_offload = ["prior"] + model_cpu_offload_seq = "text_encoder->prior" + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModelWithProjection, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: UnCLIPScheduler, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) + def interpolate( + self, + images_and_prompts: List[Union[str, PIL.Image.Image, torch.Tensor]], + weights: List[float], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + negative_prior_prompt: Optional[str] = None, + negative_prompt: str = "", + guidance_scale: float = 4.0, + device=None, + ): + """ + Function invoked when using the prior pipeline for interpolation. + + Args: + images_and_prompts (`List[Union[str, PIL.Image.Image, torch.Tensor]]`): + list of prompts and images to guide the image generation. + weights: (`List[float]`): + list of weights for each condition in `images_and_prompts` + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + negative_prior_prompt (`str`, *optional*): + The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + device = device or self.device + + if len(images_and_prompts) != len(weights): + raise ValueError( + f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" + ) + + image_embeddings = [] + for cond, weight in zip(images_and_prompts, weights): + if isinstance(cond, str): + image_emb = self( + cond, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ).image_embeds + + elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): + if isinstance(cond, PIL.Image.Image): + cond = ( + self.image_processor(cond, return_tensors="pt") + .pixel_values[0] + .unsqueeze(0) + .to(dtype=self.image_encoder.dtype, device=device) + ) + + image_emb = self.image_encoder(cond)["image_embeds"] + + else: + raise ValueError( + f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" + ) + + image_embeddings.append(image_emb * weight) + + image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True) + + out_zero = self( + negative_prompt, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ) + zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds + + return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def get_zero_embed(self, batch_size=1, device=None): + device = device or self.device + zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( + device=device, dtype=self.image_encoder.dtype + ) + zero_image_emb = self.image_encoder(zero_img)["image_embeds"] + zero_image_emb = zero_image_emb.repeat(batch_size, 1) + return zero_image_emb + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + guidance_scale: float = 4.0, + output_type: Optional[str] = "pt", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # if the negative prompt is defined we double the batch size to + # directly retrieve the negative prompt embedding + if negative_prompt is not None: + prompt = prompt + negative_prompt + negative_prompt = 2 * negative_prompt + + device = self._execution_device + + batch_size = len(prompt) + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # prior + self.scheduler.set_timesteps(num_inference_steps, device=device) + prior_timesteps_tensor = self.scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + latents = self.prior.post_process_latents(latents) + + image_embeddings = latents + + # if negative prompt has been defined, we retrieve split the image embedding into two + if negative_prompt is None: + zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) + + self.maybe_free_model_hooks() + else: + image_embeddings, zero_embeds = image_embeddings.chunk(2) + + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.prior_hook.offload() + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + image_embeddings = image_embeddings.cpu().numpy() + zero_embeds = zero_embeds.cpu().numpy() + + if not return_dict: + return (image_embeddings, zero_embeds) + + return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) diff --git a/diffusers/src/diffusers/pipelines/kandinsky/text_encoder.py b/diffusers/src/diffusers/pipelines/kandinsky/text_encoder.py new file mode 100755 index 0000000..caa0029 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky/text_encoder.py @@ -0,0 +1,27 @@ +import torch +from transformers import PreTrainedModel, XLMRobertaConfig, XLMRobertaModel + + +class MCLIPConfig(XLMRobertaConfig): + model_type = "M-CLIP" + + def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs): + self.transformerDimensions = transformerDimSize + self.numDims = imageDimSize + super().__init__(**kwargs) + + +class MultilingualCLIP(PreTrainedModel): + config_class = MCLIPConfig + + def __init__(self, config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.transformer = XLMRobertaModel(config) + self.LinearTransformation = torch.nn.Linear( + in_features=config.transformerDimensions, out_features=config.numDims + ) + + def forward(self, input_ids, attention_mask): + embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0] + embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None] + return self.LinearTransformation(embs2), embs diff --git a/diffusers/src/diffusers/pipelines/kandinsky2_2/__init__.py b/diffusers/src/diffusers/pipelines/kandinsky2_2/__init__.py new file mode 100755 index 0000000..67e97f1 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky2_2/__init__.py @@ -0,0 +1,70 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_kandinsky2_2"] = ["KandinskyV22Pipeline"] + _import_structure["pipeline_kandinsky2_2_combined"] = [ + "KandinskyV22CombinedPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22InpaintCombinedPipeline", + ] + _import_structure["pipeline_kandinsky2_2_controlnet"] = ["KandinskyV22ControlnetPipeline"] + _import_structure["pipeline_kandinsky2_2_controlnet_img2img"] = ["KandinskyV22ControlnetImg2ImgPipeline"] + _import_structure["pipeline_kandinsky2_2_img2img"] = ["KandinskyV22Img2ImgPipeline"] + _import_structure["pipeline_kandinsky2_2_inpainting"] = ["KandinskyV22InpaintPipeline"] + _import_structure["pipeline_kandinsky2_2_prior"] = ["KandinskyV22PriorPipeline"] + _import_structure["pipeline_kandinsky2_2_prior_emb2emb"] = ["KandinskyV22PriorEmb2EmbPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_kandinsky2_2 import KandinskyV22Pipeline + from .pipeline_kandinsky2_2_combined import ( + KandinskyV22CombinedPipeline, + KandinskyV22Img2ImgCombinedPipeline, + KandinskyV22InpaintCombinedPipeline, + ) + from .pipeline_kandinsky2_2_controlnet import KandinskyV22ControlnetPipeline + from .pipeline_kandinsky2_2_controlnet_img2img import KandinskyV22ControlnetImg2ImgPipeline + from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline + from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline + from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline + from .pipeline_kandinsky2_2_prior_emb2emb import KandinskyV22PriorEmb2EmbPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py new file mode 100755 index 0000000..471db61 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py @@ -0,0 +1,320 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Union + +import torch + +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler +from ...utils import deprecate, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") + >>> pipe_prior.to("cuda") + >>> prompt = "red cat, 4k photo" + >>> out = pipe_prior(prompt) + >>> image_emb = out.image_embeds + >>> zero_image_emb = out.negative_image_embeds + >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") + >>> pipe.to("cuda") + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ).images + >>> image[0].save("cat.png") + ``` +""" + + +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +class KandinskyV22Pipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "unet->movq" + _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeds: Union[torch.Tensor, List[torch.Tensor]], + negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + batch_size = image_embeds.shape[0] * num_images_per_prompt + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if self.do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=self.unet.dtype, device=device + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_channels_latents = self.unet.config.in_channels + + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + image_embeds = callback_outputs.pop("image_embeds", image_embeds) + negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if not output_type == "latent": + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + else: + image = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py new file mode 100755 index 0000000..68334fe --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py @@ -0,0 +1,854 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Union + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer, UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler, UnCLIPScheduler +from ...utils import deprecate, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from .pipeline_kandinsky2_2 import KandinskyV22Pipeline +from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline +from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline +from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +TEXT2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipe = AutoPipelineForText2Image.from_pretrained( + "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k" + + image = pipe(prompt=prompt, num_inference_steps=25).images[0] + ``` +""" + +IMAGE2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForImage2Image + import torch + import requests + from io import BytesIO + from PIL import Image + import os + + pipe = AutoPipelineForImage2Image.from_pretrained( + "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A fantasy landscape, Cinematic lighting" + negative_prompt = "low quality, bad quality" + + url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + response = requests.get(url) + image = Image.open(BytesIO(response.content)).convert("RGB") + image.thumbnail((768, 768)) + + image = pipe(prompt=prompt, image=original_image, num_inference_steps=25).images[0] + ``` +""" + +INPAINT_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForInpainting + from diffusers.utils import load_image + import torch + import numpy as np + + pipe = AutoPipelineForInpainting.from_pretrained( + "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A fantasy landscape, Cinematic lighting" + negative_prompt = "low quality, bad quality" + + original_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + ) + + mask = np.zeros((768, 768), dtype=np.float32) + # Let's mask out an area above the cat's head + mask[:250, 250:-250] = 1 + + image = pipe(prompt=prompt, image=original_image, mask_image=mask, num_inference_steps=25).images[0] + ``` +""" + + +class KandinskyV22CombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + prior_image_processor ([`CLIPImageProcessor`]): + A image_processor to be used to preprocess image from clip. + """ + + model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq" + _load_connected_pipes = True + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyV22PriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyV22Pipeline( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + return_dict: bool = True, + prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + prior_callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference of the prior pipeline. + The function is called with the following arguments: `prior_callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. + prior_callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the + list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your prior pipeline class. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference of the decoder pipeline. + The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, + step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors + as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + outputs = self.decoder_pipe( + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + width=width, + height=height, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + callback=callback, + callback_steps=callback_steps, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + self.maybe_free_model_hooks() + + return outputs + + +class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + prior_image_processor ([`CLIPImageProcessor`]): + A image_processor to be used to preprocess image from clip. + """ + + model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq" + _load_connected_pipes = True + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyV22PriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyV22Img2ImgPipeline( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(IMAGE2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + strength: float = 0.3, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + return_dict: bool = True, + prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + image = [image] if isinstance(image, PIL.Image.Image) else image + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + if ( + isinstance(image, (list, tuple)) + and len(image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(image) == 0 + ): + image = (image_embeds.shape[0] // len(image)) * image + + outputs = self.decoder_pipe( + image=image, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + width=width, + height=height, + strength=strength, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + callback=callback, + callback_steps=callback_steps, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self.maybe_free_model_hooks() + return outputs + + +class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for inpainting generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + prior_image_processor ([`CLIPImageProcessor`]): + A image_processor to be used to preprocess image from clip. + """ + + model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq" + _load_connected_pipes = True + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyV22PriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyV22InpaintPipeline( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(INPAINT_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], + mask_image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + mask_image (`np.array`): + Tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while + black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single + channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, + so the expected shape would be `(B, H, W, 1)`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + prior_callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: + int, callback_kwargs: Dict)`. + prior_callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the + list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeline class. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_kwargs = {} + if kwargs.get("prior_callback", None) is not None: + prior_kwargs["callback"] = kwargs.pop("prior_callback") + deprecate( + "prior_callback", + "1.0.0", + "Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", + ) + if kwargs.get("prior_callback_steps", None) is not None: + deprecate( + "prior_callback_steps", + "1.0.0", + "Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", + ) + prior_kwargs["callback_steps"] = kwargs.pop("prior_callback_steps") + + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + **prior_kwargs, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + image = [image] if isinstance(image, PIL.Image.Image) else image + mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + if ( + isinstance(image, (list, tuple)) + and len(image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(image) == 0 + ): + image = (image_embeds.shape[0] // len(image)) * image + + if ( + isinstance(mask_image, (list, tuple)) + and len(mask_image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(mask_image) == 0 + ): + mask_image = (image_embeds.shape[0] // len(mask_image)) * mask_image + + outputs = self.decoder_pipe( + image=image, + mask_image=mask_image, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + width=width, + height=height, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + **kwargs, + ) + self.maybe_free_model_hooks() + + return outputs diff --git a/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py new file mode 100755 index 0000000..0130c39 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py @@ -0,0 +1,320 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Union + +import torch + +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler +from ...utils import ( + logging, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> import numpy as np + + >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline + >>> from transformers import pipeline + >>> from diffusers.utils import load_image + + + >>> def make_hint(image, depth_estimator): + ... image = depth_estimator(image)["depth"] + ... image = np.array(image) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... detected_map = torch.from_numpy(image).float() / 255.0 + ... hint = detected_map.permute(2, 0, 1) + ... return hint + + + >>> depth_estimator = pipeline("depth-estimation") + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior = pipe_prior.to("cuda") + + >>> pipe = KandinskyV22ControlnetPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + + >>> img = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ).resize((768, 768)) + + >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda") + + >>> prompt = "A robot, 4k photo" + >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature" + + >>> generator = torch.Generator(device="cuda").manual_seed(43) + + >>> image_emb, zero_image_emb = pipe_prior( + ... prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator + ... ).to_tuple() + + >>> images = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... hint=hint, + ... num_inference_steps=50, + ... generator=generator, + ... height=768, + ... width=768, + ... ).images + + >>> images[0].save("robot_cat.png") + ``` +""" + + +# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +class KandinskyV22ControlnetPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "unet->movq" + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image_embeds: Union[torch.Tensor, List[torch.Tensor]], + negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]], + hint: torch.Tensor, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + hint (`torch.Tensor`): + The controlnet condition. + image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + if isinstance(hint, list): + hint = torch.cat(hint, dim=0) + + batch_size = image_embeds.shape[0] * num_images_per_prompt + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hint = hint.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=self.unet.dtype, device=device + ) + hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.movq.config.latent_channels + + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + # Offload all models + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py new file mode 100755 index 0000000..12be153 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py @@ -0,0 +1,381 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from PIL import Image + +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler +from ...utils import ( + logging, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> import numpy as np + + >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline + >>> from transformers import pipeline + >>> from diffusers.utils import load_image + + + >>> def make_hint(image, depth_estimator): + ... image = depth_estimator(image)["depth"] + ... image = np.array(image) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... detected_map = torch.from_numpy(image).float() / 255.0 + ... hint = detected_map.permute(2, 0, 1) + ... return hint + + + >>> depth_estimator = pipeline("depth-estimation") + + >>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior = pipe_prior.to("cuda") + + >>> pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> img = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ).resize((768, 768)) + + + >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda") + + >>> prompt = "A robot, 4k photo" + >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature" + + >>> generator = torch.Generator(device="cuda").manual_seed(43) + + >>> img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator) + >>> negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator) + + >>> images = pipe( + ... image=img, + ... strength=0.5, + ... image_embeds=img_emb.image_embeds, + ... negative_image_embeds=negative_emb.image_embeds, + ... hint=hint, + ... num_inference_steps=50, + ... generator=generator, + ... height=768, + ... width=768, + ... ).images + + >>> images[0].save("robot_cat.png") + ``` +""" + + +# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image +def prepare_image(pil_image, w=512, h=512): + pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + arr = np.array(pil_image.convert("RGB")) + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2, 0, 1]) + image = torch.from_numpy(arr).unsqueeze(0) + return image + + +class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "unet->movq" + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2_img2img.KandinskyV22Img2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.movq.encode(image).latent_dist.sample(generator) + + init_latents = self.movq.config.scaling_factor * init_latents + + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + @torch.no_grad() + def __call__( + self, + image_embeds: Union[torch.Tensor, List[torch.Tensor]], + image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], + negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]], + hint: torch.Tensor, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + strength: float = 0.3, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + hint (`torch.Tensor`): + The controlnet condition. + negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + if isinstance(hint, list): + hint = torch.cat(hint, dim=0) + + batch_size = image_embeds.shape[0] + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hint = hint.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=self.unet.dtype, device=device + ) + hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device) + + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) + image = image.to(dtype=image_embeds.dtype, device=device) + + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + latents = self.prepare_latents( + latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator + ) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + # Offload all models + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py new file mode 100755 index 0000000..899273a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py @@ -0,0 +1,399 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from PIL import Image + +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler +from ...utils import deprecate, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Img2ImgPipeline, KandinskyV22PriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "A red cartoon frog, 4k" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyV22Img2ImgPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/frog.png" + ... ) + + >>> image = pipe( + ... image=init_image, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... strength=0.2, + ... ).images + + >>> image[0].save("red_frog.png") + ``` +""" + + +# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image +def prepare_image(pil_image, w=512, h=512): + pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + arr = np.array(pil_image.convert("RGB")) + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2, 0, 1]) + image = torch.from_numpy(arr).unsqueeze(0) + return image + + +class KandinskyV22Img2ImgPipeline(DiffusionPipeline): + """ + Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "unet->movq" + _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.movq.encode(image).latent_dist.sample(generator) + + init_latents = self.movq.config.scaling_factor * init_latents + + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + image_embeds: Union[torch.Tensor, List[torch.Tensor]], + image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], + negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + strength: float = 0.3, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + batch_size = image_embeds.shape[0] + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if self.do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=self.unet.dtype, device=device + ) + + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) + image = image.to(dtype=image_embeds.dtype, device=device) + + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + latents = self.prepare_latents( + latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator + ) + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + image_embeds = callback_outputs.pop("image_embeds", image_embeds) + negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}" + ) + + if not output_type == "latent": + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py new file mode 100755 index 0000000..b5ba7a0 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py @@ -0,0 +1,556 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from packaging import version +from PIL import Image + +from ... import __version__ +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler +from ...utils import deprecate, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + >>> import numpy as np + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "a hat" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyV22InpaintPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> mask = np.zeros((768, 768), dtype=np.float32) + >>> mask[:250, 250:-250] = 1 + + >>> out = pipe( + ... image=init_image, + ... mask_image=mask, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ) + + >>> image = out.images[0] + >>> image.save("cat_with_hat.png") + ``` +""" + + +# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask +def prepare_mask(masks): + prepared_masks = [] + for mask in masks: + old_mask = deepcopy(mask) + for i in range(mask.shape[1]): + for j in range(mask.shape[2]): + if old_mask[0][i][j] == 1: + continue + if i != 0: + mask[:, i - 1, j] = 0 + if j != 0: + mask[:, i, j - 1] = 0 + if i != 0 and j != 0: + mask[:, i - 1, j - 1] = 0 + if i != mask.shape[1] - 1: + mask[:, i + 1, j] = 0 + if j != mask.shape[2] - 1: + mask[:, i, j + 1] = 0 + if i != mask.shape[1] - 1 and j != mask.shape[2] - 1: + mask[:, i + 1, j + 1] = 0 + prepared_masks.append(mask) + return torch.stack(prepared_masks, dim=0) + + +# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask_and_masked_image +def prepare_mask_and_masked_image(image, mask, height, width): + r""" + Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will + be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for + the ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + mask = 1 - mask + + return mask, image + + +class KandinskyV22InpaintPipeline(DiffusionPipeline): + """ + Pipeline for text-guided image inpainting using Kandinsky2.1 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "unet->movq" + _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds", "masked_image", "mask_image"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + self._warn_has_been_called = False + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + image_embeds: Union[torch.Tensor, List[torch.Tensor]], + image: Union[torch.Tensor, PIL.Image.Image], + mask_image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], + negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`np.array`): + Tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while + black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single + channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, + so the expected shape would be `(B, H, W, 1)`. + negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + if not self._warn_has_been_called and version.parse(version.parse(__version__).base_version) < version.parse( + "0.23.0.dev0" + ): + logger.warning( + "Please note that the expected format of `mask_image` has recently been changed. " + "Before diffusers == 0.19.0, Kandinsky Inpainting pipelines repainted black pixels and preserved black pixels. " + "As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. " + "This way, Kandinsky's masking behavior is aligned with Stable Diffusion. " + "THIS means that you HAVE to invert the input mask to have the same behavior as before as explained in https://github.com/huggingface/diffusers/pull/4207. " + "This warning will be surpressed after the first inference call and will be removed in diffusers>0.23.0" + ) + self._warn_has_been_called = True + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + self._guidance_scale = guidance_scale + + device = self._execution_device + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + batch_size = image_embeds.shape[0] * num_images_per_prompt + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if self.do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=self.unet.dtype, device=device + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # preprocess image and mask + mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) + + image = image.to(dtype=image_embeds.dtype, device=device) + image = self.movq.encode(image)["latents"] + + mask_image = mask_image.to(dtype=image_embeds.dtype, device=device) + + image_shape = tuple(image.shape[-2:]) + mask_image = F.interpolate( + mask_image, + image_shape, + mode="nearest", + ) + mask_image = prepare_mask(mask_image) + masked_image = image * mask_image + + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0) + if self.do_classifier_free_guidance: + mask_image = mask_image.repeat(2, 1, 1, 1) + masked_image = masked_image.repeat(2, 1, 1, 1) + + num_channels_latents = self.movq.config.latent_channels + + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + noise = torch.clone(latents) + + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) + + added_cond_kwargs = {"image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + init_latents_proper = image[:1] + init_mask = mask_image[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = init_mask * init_latents_proper + (1 - init_mask) * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + image_embeds = callback_outputs.pop("image_embeds", image_embeds) + negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds) + masked_image = callback_outputs.pop("masked_image", masked_image) + mask_image = callback_outputs.pop("mask_image", mask_image) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # post-processing + latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents + + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}" + ) + + if not output_type == "latent": + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py new file mode 100755 index 0000000..f2134b2 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py @@ -0,0 +1,549 @@ +from typing import Callable, Dict, List, Optional, Union + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer +from ...schedulers import UnCLIPScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..kandinsky import KandinskyPriorPipelineOutput +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") + >>> pipe_prior.to("cuda") + >>> prompt = "red cat, 4k photo" + >>> image_emb, negative_image_emb = pipe_prior(prompt).to_tuple() + + >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") + >>> pipe.to("cuda") + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ).images + >>> image[0].save("cat.png") + ``` +""" + +EXAMPLE_INTERPOLATE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline + >>> from diffusers.utils import load_image + >>> import PIL + >>> import torch + >>> from torchvision import transforms + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + >>> img1 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + >>> img2 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/starry_night.jpeg" + ... ) + >>> images_texts = ["a cat", img1, img2] + >>> weights = [0.3, 0.3, 0.4] + >>> out = pipe_prior.interpolate(images_texts, weights) + >>> pipe = KandinskyV22Pipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + >>> image = pipe( + ... image_embeds=out.image_embeds, + ... negative_image_embeds=out.negative_image_embeds, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ).images[0] + >>> image.save("starry_cat.png") + ``` +""" + + +class KandinskyV22PriorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + image_processor ([`CLIPImageProcessor`]): + A image_processor to be used to preprocess image from clip. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->prior" + _exclude_from_cpu_offload = ["prior"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "text_encoder_hidden_states", "text_mask"] + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModelWithProjection, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: UnCLIPScheduler, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) + def interpolate( + self, + images_and_prompts: List[Union[str, PIL.Image.Image, torch.Tensor]], + weights: List[float], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + negative_prior_prompt: Optional[str] = None, + negative_prompt: str = "", + guidance_scale: float = 4.0, + device=None, + ): + """ + Function invoked when using the prior pipeline for interpolation. + + Args: + images_and_prompts (`List[Union[str, PIL.Image.Image, torch.Tensor]]`): + list of prompts and images to guide the image generation. + weights: (`List[float]`): + list of weights for each condition in `images_and_prompts` + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + negative_prior_prompt (`str`, *optional*): + The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + device = device or self.device + + if len(images_and_prompts) != len(weights): + raise ValueError( + f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" + ) + + image_embeddings = [] + for cond, weight in zip(images_and_prompts, weights): + if isinstance(cond, str): + image_emb = self( + cond, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ).image_embeds.unsqueeze(0) + + elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): + if isinstance(cond, PIL.Image.Image): + cond = ( + self.image_processor(cond, return_tensors="pt") + .pixel_values[0] + .unsqueeze(0) + .to(dtype=self.image_encoder.dtype, device=device) + ) + + image_emb = self.image_encoder(cond)["image_embeds"].repeat(num_images_per_prompt, 1).unsqueeze(0) + + else: + raise ValueError( + f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" + ) + + image_embeddings.append(image_emb * weight) + + image_emb = torch.cat(image_embeddings).sum(dim=0) + + out_zero = self( + negative_prompt, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ) + zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds + + return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed + def get_zero_embed(self, batch_size=1, device=None): + device = device or self.device + zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( + device=device, dtype=self.image_encoder.dtype + ) + zero_image_emb = self.image_encoder(zero_img)["image_embeds"] + zero_image_emb = zero_image_emb.repeat(batch_size, 1) + return zero_image_emb + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + guidance_scale: float = 4.0, + output_type: Optional[str] = "pt", # pt only + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # if the negative prompt is defined we double the batch size to + # directly retrieve the negative prompt embedding + if negative_prompt is not None: + prompt = prompt + negative_prompt + negative_prompt = 2 * negative_prompt + + device = self._execution_device + + batch_size = len(prompt) + batch_size = batch_size * num_images_per_prompt + + self._guidance_scale = guidance_scale + + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt + ) + + # prior + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if self.do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + self.guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == timesteps.shape[0]: + prev_timestep = None + else: + prev_timestep = timesteps[i + 1] + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + text_encoder_hidden_states = callback_outputs.pop( + "text_encoder_hidden_states", text_encoder_hidden_states + ) + text_mask = callback_outputs.pop("text_mask", text_mask) + + latents = self.prior.post_process_latents(latents) + + image_embeddings = latents + + # if negative prompt has been defined, we retrieve split the image embedding into two + if negative_prompt is None: + zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) + else: + image_embeddings, zero_embeds = image_embeddings.chunk(2) + + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + image_embeddings = image_embeddings.cpu().numpy() + zero_embeds = zero_embeds.cpu().numpy() + + if not return_dict: + return (image_embeddings, zero_embeds) + + return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) diff --git a/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py new file mode 100755 index 0000000..ec6509b --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py @@ -0,0 +1,563 @@ +from typing import List, Optional, Union + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer +from ...schedulers import UnCLIPScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..kandinsky import KandinskyPriorPipelineOutput +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "red cat, 4k photo" + >>> img = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + >>> image_emb, nagative_image_emb = pipe_prior(prompt, image=img, strength=0.2).to_tuple() + + >>> pipe = KandinskyPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder, torch_dtype=torch.float16" + ... ) + >>> pipe.to("cuda") + + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... ).images + + >>> image[0].save("cat.png") + ``` +""" + +EXAMPLE_INTERPOLATE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22Pipeline + >>> from diffusers.utils import load_image + >>> import PIL + + >>> import torch + >>> from torchvision import transforms + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> img1 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> img2 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/starry_night.jpeg" + ... ) + + >>> images_texts = ["a cat", img1, img2] + >>> weights = [0.3, 0.3, 0.4] + >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) + + >>> pipe = KandinskyV22Pipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=150, + ... ).images[0] + + >>> image.save("starry_cat.png") + ``` +""" + + +class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->prior" + _exclude_from_cpu_offload = ["prior"] + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModelWithProjection, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: UnCLIPScheduler, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) + def interpolate( + self, + images_and_prompts: List[Union[str, PIL.Image.Image, torch.Tensor]], + weights: List[float], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + negative_prior_prompt: Optional[str] = None, + negative_prompt: str = "", + guidance_scale: float = 4.0, + device=None, + ): + """ + Function invoked when using the prior pipeline for interpolation. + + Args: + images_and_prompts (`List[Union[str, PIL.Image.Image, torch.Tensor]]`): + list of prompts and images to guide the image generation. + weights: (`List[float]`): + list of weights for each condition in `images_and_prompts` + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + negative_prior_prompt (`str`, *optional*): + The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + device = device or self.device + + if len(images_and_prompts) != len(weights): + raise ValueError( + f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" + ) + + image_embeddings = [] + for cond, weight in zip(images_and_prompts, weights): + if isinstance(cond, str): + image_emb = self( + cond, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ).image_embeds.unsqueeze(0) + + elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): + image_emb = self._encode_image( + cond, device=device, num_images_per_prompt=num_images_per_prompt + ).unsqueeze(0) + + else: + raise ValueError( + f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" + ) + + image_embeddings.append(image_emb * weight) + + image_emb = torch.cat(image_embeddings).sum(dim=0) + + return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=torch.randn_like(image_emb)) + + def _encode_image( + self, + image: Union[torch.Tensor, List[PIL.Image.Image]], + device, + num_images_per_prompt, + ): + if not isinstance(image, torch.Tensor): + image = self.image_processor(image, return_tensors="pt").pixel_values.to( + dtype=self.image_encoder.dtype, device=device + ) + + image_emb = self.image_encoder(image)["image_embeds"] # B, D + image_emb = image_emb.repeat_interleave(num_images_per_prompt, dim=0) + image_emb.to(device=device) + + return image_emb + + def prepare_latents(self, emb, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + emb = emb.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + init_latents = emb + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed + def get_zero_embed(self, batch_size=1, device=None): + device = device or self.device + zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( + device=device, dtype=self.image_encoder.dtype + ) + zero_image_emb = self.image_encoder(zero_img)["image_embeds"] + zero_image_emb = zero_image_emb.repeat(batch_size, 1) + return zero_image_emb + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], + strength: float = 0.3, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + guidance_scale: float = 4.0, + output_type: Optional[str] = "pt", # pt only + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `emb`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. + emb (`torch.Tensor`): + The image embedding. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # if the negative prompt is defined we double the batch size to + # directly retrieve the negative prompt embedding + if negative_prompt is not None: + prompt = prompt + negative_prompt + negative_prompt = 2 * negative_prompt + + device = self._execution_device + + batch_size = len(prompt) + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if not isinstance(image, List): + image = [image] + + if isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + if isinstance(image, torch.Tensor) and image.ndim == 2: + # allow user to pass image_embeds directly + image_embeds = image.repeat_interleave(num_images_per_prompt, dim=0) + elif isinstance(image, torch.Tensor) and image.ndim != 4: + raise ValueError( + f" if pass `image` as pytorch tensor, or a list of pytorch tensor, please make sure each tensor has shape [batch_size, channels, height, width], currently {image[0].unsqueeze(0).shape}" + ) + else: + image_embeds = self._encode_image(image, device, num_images_per_prompt) + + # prior + self.scheduler.set_timesteps(num_inference_steps, device=device) + + latents = image_embeds + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size) + latents = self.prepare_latents( + latents, + latent_timestep, + batch_size // num_images_per_prompt, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == timesteps.shape[0]: + prev_timestep = None + else: + prev_timestep = timesteps[i + 1] + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + latents = self.prior.post_process_latents(latents) + + image_embeddings = latents + + # if negative prompt has been defined, we retrieve split the image embedding into two + if negative_prompt is None: + zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) + else: + image_embeddings, zero_embeds = image_embeddings.chunk(2) + + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + image_embeddings = image_embeddings.cpu().numpy() + zero_embeds = zero_embeds.cpu().numpy() + + if not return_dict: + return (image_embeddings, zero_embeds) + + return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) diff --git a/diffusers/src/diffusers/pipelines/kandinsky3/__init__.py b/diffusers/src/diffusers/pipelines/kandinsky3/__init__.py new file mode 100755 index 0000000..e8a3063 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky3/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_kandinsky3"] = ["Kandinsky3Pipeline"] + _import_structure["pipeline_kandinsky3_img2img"] = ["Kandinsky3Img2ImgPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_kandinsky3 import Kandinsky3Pipeline + from .pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py b/diffusers/src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py new file mode 100755 index 0000000..5360632 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +import argparse +import fnmatch + +from safetensors.torch import load_file + +from diffusers import Kandinsky3UNet + + +MAPPING = { + "to_time_embed.1": "time_embedding.linear_1", + "to_time_embed.3": "time_embedding.linear_2", + "in_layer": "conv_in", + "out_layer.0": "conv_norm_out", + "out_layer.2": "conv_out", + "down_samples": "down_blocks", + "up_samples": "up_blocks", + "projection_lin": "encoder_hid_proj.projection_linear", + "projection_ln": "encoder_hid_proj.projection_norm", + "feature_pooling": "add_time_condition", + "to_query": "to_q", + "to_key": "to_k", + "to_value": "to_v", + "output_layer": "to_out.0", + "self_attention_block": "attentions.0", +} + +DYNAMIC_MAP = { + "resnet_attn_blocks.*.0": "resnets_in.*", + "resnet_attn_blocks.*.1": ("attentions.*", 1), + "resnet_attn_blocks.*.2": "resnets_out.*", +} +# MAPPING = {} + + +def convert_state_dict(unet_state_dict): + """ + Args: + Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model. + unet_model (torch.nn.Module): The original U-Net model. unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet + model to match keys with. + + Returns: + OrderedDict: The converted state dictionary. + """ + # Example of renaming logic (this will vary based on your model's architecture) + converted_state_dict = {} + for key in unet_state_dict: + new_key = key + for pattern, new_pattern in MAPPING.items(): + new_key = new_key.replace(pattern, new_pattern) + + for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items(): + has_matched = False + if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched: + star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1]) + + if isinstance(dyn_new_pattern, tuple): + new_star = star + dyn_new_pattern[-1] + dyn_new_pattern = dyn_new_pattern[0] + else: + new_star = star + + pattern = dyn_pattern.replace("*", str(star)) + new_pattern = dyn_new_pattern.replace("*", str(new_star)) + + new_key = new_key.replace(pattern, new_pattern) + has_matched = True + + converted_state_dict[new_key] = unet_state_dict[key] + + return converted_state_dict + + +def main(model_path, output_path): + # Load your original U-Net model + unet_state_dict = load_file(model_path) + + # Initialize your Kandinsky3UNet model + config = {} + + # Convert the state dict + converted_state_dict = convert_state_dict(unet_state_dict) + + unet = Kandinsky3UNet(config) + unet.load_state_dict(converted_state_dict) + + unet.save_pretrained(output_path) + print(f"Converted model saved to {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format") + parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model") + parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model") + + args = parser.parse_args() + main(args.model_path, args.output_path) diff --git a/diffusers/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py b/diffusers/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py new file mode 100755 index 0000000..8dbae2a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py @@ -0,0 +1,576 @@ +from typing import Callable, Dict, List, Optional, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import Kandinsky3UNet, VQModel +from ...schedulers import DDPMScheduler +from ...utils import ( + deprecate, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import AutoPipelineForText2Image + >>> import torch + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background." + + >>> generator = torch.Generator(device="cpu").manual_seed(0) + >>> image = pipe(prompt, num_inference_steps=25, generator=generator).images[0] + ``` + +""" + + +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +class Kandinsky3Pipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + model_cpu_offload_seq = "text_encoder->unet->movq" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "negative_attention_mask", + "attention_mask", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: Kandinsky3UNet, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq + ) + + def process_embeds(self, embeddings, attention_mask, cut_context): + if cut_context: + embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0]) + max_seq_length = attention_mask.sum(-1).max() + 1 + embeddings = embeddings[:, :max_seq_length] + attention_mask = attention_mask[:, :max_seq_length] + return embeddings, attention_mask + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + _cut_context=False, + attention_mask: Optional[torch.Tensor] = None, + negative_attention_mask: Optional[torch.Tensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask. Must provide if passing `prompt_embeds` directly. + negative_attention_mask (`torch.Tensor`, *optional*): + Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = 128 + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_embeds, attention_mask = self.process_embeds(prompt_embeds, attention_mask, _cut_context) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2) + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + attention_mask = attention_mask.repeat(num_images_per_prompt, 1) + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + if negative_prompt is not None: + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=128, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]] + negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]] + negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2) + + else: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_attention_mask = torch.zeros_like(attention_mask) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + if negative_prompt_embeds.shape != prompt_embeds.shape: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_attention_mask = negative_attention_mask.repeat(num_images_per_prompt, 1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + negative_attention_mask = None + return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + attention_mask=None, + negative_attention_mask=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if negative_prompt_embeds is not None and negative_attention_mask is None: + raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`") + + if negative_prompt_embeds is not None and negative_attention_mask is not None: + if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape: + raise ValueError( + "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but" + f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`" + f" {negative_attention_mask.shape}." + ) + + if prompt_embeds is not None and attention_mask is None: + raise ValueError("Please provide `attention_mask` along with `prompt_embeds`") + + if prompt_embeds is not None and attention_mask is not None: + if prompt_embeds.shape[:2] != attention_mask.shape: + raise ValueError( + "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`" + f" {attention_mask.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 25, + guidance_scale: float = 3.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = 1024, + width: Optional[int] = 1024, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + negative_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + latents=None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 3.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask. Must provide if passing `prompt_embeds` directly. + negative_attention_mask (`torch.Tensor`, *optional*): + Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + cut_context = True + device = self._execution_device + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + attention_mask, + negative_attention_mask, + ) + + self._guidance_scale = guidance_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + _cut_context=cut_context, + attention_mask=attention_mask, + negative_attention_mask=negative_attention_mask, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + height, width = downscale_height_and_width(height, width, 8) + + latents = self.prepare_latents( + (batch_size * num_images_per_prompt, 4, height, width), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=attention_mask, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + + noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond + # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + attention_mask = callback_outputs.pop("attention_mask", attention_mask) + negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # post-processing + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}" + ) + + if not output_type == "latent": + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + else: + image = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py b/diffusers/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py new file mode 100755 index 0000000..81c45c4 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py @@ -0,0 +1,643 @@ +import inspect +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import PIL.Image +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import Kandinsky3UNet, VQModel +from ...schedulers import DDPMScheduler +from ...utils import ( + deprecate, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import AutoPipelineForImage2Image + >>> from diffusers.utils import load_image + >>> import torch + + >>> pipe = AutoPipelineForImage2Image.from_pretrained( + ... "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A painting of the inside of a subway train with tiny raccoons." + >>> image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png" + ... ) + + >>> generator = torch.Generator(device="cpu").manual_seed(0) + >>> image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0] + ``` +""" + + +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +def prepare_image(pil_image): + arr = np.array(pil_image.convert("RGB")) + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2, 0, 1]) + image = torch.from_numpy(arr).unsqueeze(0) + return image + + +class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + model_cpu_offload_seq = "text_encoder->movq->unet->movq" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "negative_attention_mask", + "attention_mask", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: Kandinsky3UNet, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def _process_embeds(self, embeddings, attention_mask, cut_context): + # return embeddings, attention_mask + if cut_context: + embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0]) + max_seq_length = attention_mask.sum(-1).max() + 1 + embeddings = embeddings[:, :max_seq_length] + attention_mask = attention_mask[:, :max_seq_length] + return embeddings, attention_mask + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + _cut_context=False, + attention_mask: Optional[torch.Tensor] = None, + negative_attention_mask: Optional[torch.Tensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask. Must provide if passing `prompt_embeds` directly. + negative_attention_mask (`torch.Tensor`, *optional*): + Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = 128 + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_embeds, attention_mask = self._process_embeds(prompt_embeds, attention_mask, _cut_context) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2) + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + attention_mask = attention_mask.repeat(num_images_per_prompt, 1) + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + if negative_prompt is not None: + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=128, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]] + negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]] + negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2) + + else: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_attention_mask = torch.zeros_like(attention_mask) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + if negative_prompt_embeds.shape != prompt_embeds.shape: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_attention_mask = negative_attention_mask.repeat(num_images_per_prompt, 1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + negative_attention_mask = None + return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.movq.encode(image).latent_dist.sample(generator) + + init_latents = self.movq.config.scaling_factor * init_latents + + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + attention_mask=None, + negative_attention_mask=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if negative_prompt_embeds is not None and negative_attention_mask is None: + raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`") + + if negative_prompt_embeds is not None and negative_attention_mask is not None: + if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape: + raise ValueError( + "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but" + f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`" + f" {negative_attention_mask.shape}." + ) + + if prompt_embeds is not None and attention_mask is None: + raise ValueError("Please provide `attention_mask` along with `prompt_embeds`") + + if prompt_embeds is not None and attention_mask is not None: + if prompt_embeds.shape[:2] != attention_mask.shape: + raise ValueError( + "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`" + f" {attention_mask.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None, + strength: float = 0.3, + num_inference_steps: int = 25, + guidance_scale: float = 3.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + negative_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 3.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask. Must provide if passing `prompt_embeds` directly. + negative_attention_mask (`torch.Tensor`, *optional*): + Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + + """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + cut_context = True + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + attention_mask, + negative_attention_mask, + ) + + self._guidance_scale = guidance_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + _cut_context=cut_context, + attention_mask=attention_mask, + negative_attention_mask=negative_attention_mask, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([prepare_image(i) for i in image], dim=0) + image = image.to(dtype=prompt_embeds.dtype, device=device) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + # 5. Prepare latents + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latents = self.prepare_latents( + latents, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=attention_mask, + )[0] + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + + noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + attention_mask = callback_outputs.pop("attention_mask", attention_mask) + negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # post-processing + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}" + ) + if not output_type == "latent": + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + else: + image = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kolors/__init__.py b/diffusers/src/diffusers/pipelines/kolors/__init__.py new file mode 100755 index 0000000..671d22e --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kolors/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_sentencepiece_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()) and is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects)) +else: + _import_structure["pipeline_kolors"] = ["KolorsPipeline"] + _import_structure["pipeline_kolors_img2img"] = ["KolorsImg2ImgPipeline"] + _import_structure["text_encoder"] = ["ChatGLMModel"] + _import_structure["tokenizer"] = ["ChatGLMTokenizer"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()) and is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_and_sentencepiece_objects import * + + else: + from .pipeline_kolors import KolorsPipeline + from .pipeline_kolors_img2img import KolorsImg2ImgPipeline + from .text_encoder import ChatGLMModel + from .tokenizer import ChatGLMTokenizer + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/kolors/pipeline_kolors.py b/diffusers/src/diffusers/pipelines/kolors/pipeline_kolors.py new file mode 100755 index 0000000..b682429 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kolors/pipeline_kolors.py @@ -0,0 +1,1070 @@ +# Copyright 2024 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import KolorsPipelineOutput +from .text_encoder import ChatGLMModel +from .tokenizer import ChatGLMTokenizer + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import KolorsPipeline + + >>> pipe = KolorsPipeline.from_pretrained( + ... "Kwai-Kolors/Kolors-diffusers", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = ( + ... "A photo of a ladybug, macro, zoom, high quality, film, holding a wooden sign with the text 'KOLORS'" + ... ) + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin): + r""" + Pipeline for text-to-image generation using Kolors. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`ChatGLMModel`]): + Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b). + tokenizer (`ChatGLMTokenizer`): + Tokenizer of class + [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"False"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `Kwai-Kolors/Kolors-diffusers`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = [ + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ChatGLMModel, + tokenizer: ChatGLMTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + """ + # from IPython import embed; embed(); exit() + device = device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer] + text_encoders = [self.text_encoder] + + if prompt_embeds is None: + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=text_inputs["input_ids"], + attention_mask=text_inputs["attention_mask"], + position_ids=text_inputs["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = prompt_embeds_list[0] + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=uncond_input["input_ids"], + attention_mask=uncond_input["attention_mask"], + position_ids=uncond_input["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = negative_prompt_embeds_list[0] + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + num_inference_steps, + height, + width, + negative_prompt=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + negative_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + num_inference_steps, + height, + width, + negative_prompt, + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return KolorsPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/diffusers/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py new file mode 100755 index 0000000..4985a80 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -0,0 +1,1250 @@ +# Copyright 2024 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import KolorsPipelineOutput +from .text_encoder import ChatGLMModel +from .tokenizer import ChatGLMTokenizer + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import KolorsImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = KolorsImg2ImgPipeline.from_pretrained( + ... "Kwai-Kolors/Kolors-diffusers", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + >>> url = ( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/bunny_source.png" + ... ) + + + >>> init_image = load_image(url) + >>> prompt = "high quality image of a capybara wearing sunglasses. In the background of the image there are trees, poles, grass and other objects. At the bottom of the object there is the road., 8k, highly detailed." + >>> image = pipe(prompt, image=init_image).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin): + r""" + Pipeline for text-to-image generation using Kolors. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`ChatGLMModel`]): + Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b). + tokenizer (`ChatGLMTokenizer`): + Tokenizer of class + [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"False"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `Kwai-Kolors/Kolors-diffusers`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder-unet->vae" + _optional_components = [ + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ChatGLMModel, + tokenizer: ChatGLMTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + """ + # from IPython import embed; embed(); exit() + device = device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer] + text_encoders = [self.text_encoder] + + if prompt_embeds is None: + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=text_inputs["input_ids"], + attention_mask=text_inputs["attention_mask"], + position_ids=text_inputs["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = prompt_embeds_list[0] + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=uncond_input["input_ids"], + attention_mask=uncond_input["attention_mask"], + position_ids=uncond_input["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = negative_prompt_embeds_list[0] + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + num_inference_steps, + height, + width, + negative_prompt=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + negative_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_start(self): + return self._denoising_start + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + strength: float = 0.3, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): + The image(s) to modify with the pipeline. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of + `denoising_start` being declared as an integer, the value of `strength` will be ignored. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + num_inference_steps, + height, + width, + negative_prompt, + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True if self.denoising_start is None else False + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + add_noise, + ) + + # 7. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 9.1 Apply denoising_end + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return KolorsPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/kolors/pipeline_output.py b/diffusers/src/diffusers/pipelines/kolors/pipeline_output.py new file mode 100755 index 0000000..310ee7e --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kolors/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class KolorsPipelineOutput(BaseOutput): + """ + Output class for Kolors pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/diffusers/src/diffusers/pipelines/kolors/text_encoder.py b/diffusers/src/diffusers/pipelines/kolors/text_encoder.py new file mode 100755 index 0000000..6fb6f18 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kolors/text_encoder.py @@ -0,0 +1,889 @@ +# Copyright 2024 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import LayerNorm +from torch.nn.utils import skip_init +from transformers import PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + classifier_dropout=None, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs, + ): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.classifier_dropout = classifier_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split(".")[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [ + k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer] + ] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, is_causal=True + ) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask + ) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=query_layer.device, + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: + attention_mask = torch.ones( + output_size[0], 1, output_size[2], output_size[3], device=attention_scores.device, dtype=torch.bool + ) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, _b, np, _hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + self.query_key_value = nn.Linear( + config.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, + **_config_to_kwargs(config), + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear( + self.projection_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + **_config_to_kwargs(config), + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, + and project the state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config), + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config) + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc( + config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype + ) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc( + config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype + ) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc( + config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype + ) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache + ) + else: + layer_ret = layer( + hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index], use_cache=use_cache + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat( + (torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1 + ) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GLMTransformer): + module.gradient_checkpointing = value + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device) + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix Input shape: (batch-size, prefix-length) Output shape: (batch-size, + prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size), + ) + else: + self.embedding = torch.nn.Embedding( + config.pre_seq_len, config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + ) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) + + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype + ) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method( + nn.Linear, + config.hidden_size, + config.padded_vocab_size, + bias=False, + dtype=config.torch_dtype, + **init_kwargs, + ) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, self.pre_seq_len, self.num_layers * 2, self.multi_query_group_num, self.kv_channels + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype + ) + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1 + ) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) diff --git a/diffusers/src/diffusers/pipelines/kolors/tokenizer.py b/diffusers/src/diffusers/pipelines/kolors/tokenizer.py new file mode 100755 index 0000000..a7b942f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/kolors/tokenizer.py @@ -0,0 +1,334 @@ +# Copyright 2024 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re +from typing import Dict, List, Optional, Union + +from sentencepiece import SentencePieceProcessor +from transformers import PreTrainedTokenizer +from transformers.tokenization_utils_base import BatchEncoding, EncodedInput +from transformers.utils import PaddingStrategy + + +class SPTokenizer: + def __init__(self, model_path: str): + # reload tokenizer + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.unk_id() + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] + special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens + self.special_tokens = {} + self.index_special_tokens = {} + for token in special_tokens: + self.special_tokens[token] = self.n_words + self.index_special_tokens[self.n_words] = token + self.n_words += 1 + self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens]) + + def tokenize(self, s: str, encode_special_tokens=False): + if encode_special_tokens: + last_index = 0 + t = [] + for match in re.finditer(self.role_special_token_expression, s): + if last_index < match.start(): + t.extend(self.sp_model.EncodeAsPieces(s[last_index : match.start()])) + t.append(s[match.start() : match.end()]) + last_index = match.end() + if last_index < len(s): + t.extend(self.sp_model.EncodeAsPieces(s[last_index:])) + return t + else: + return self.sp_model.EncodeAsPieces(s) + + def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]: + assert isinstance(s, str) + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + text, buffer = "", [] + for token in t: + if token in self.index_special_tokens: + if buffer: + text += self.sp_model.decode(buffer) + buffer = [] + text += self.index_special_tokens[token] + else: + buffer.append(token) + if buffer: + text += self.sp_model.decode(buffer) + return text + + def decode_tokens(self, tokens: List[str]) -> str: + text = self.sp_model.DecodePieces(tokens) + return text + + def convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if token in self.special_tokens: + return self.special_tokens[token] + return self.sp_model.PieceToId(token) + + def convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.index_special_tokens: + return self.index_special_tokens[index] + if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0: + return "" + return self.sp_model.IdToPiece(index) + + +class ChatGLMTokenizer(PreTrainedTokenizer): + vocab_files_names = {"vocab_file": "tokenizer.model"} + + model_input_names = ["input_ids", "attention_mask", "position_ids"] + + def __init__( + self, + vocab_file, + padding_side="left", + clean_up_tokenization_spaces=False, + encode_special_tokens=False, + **kwargs, + ): + self.name = "GLMTokenizer" + + self.vocab_file = vocab_file + self.tokenizer = SPTokenizer(vocab_file) + self.special_tokens = { + "": self.tokenizer.bos_id, + "": self.tokenizer.eos_id, + "": self.tokenizer.pad_id, + } + self.encode_special_tokens = encode_special_tokens + super().__init__( + padding_side=padding_side, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + encode_special_tokens=encode_special_tokens, + **kwargs, + ) + + def get_command(self, token): + if token in self.special_tokens: + return self.special_tokens[token] + assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}" + return self.tokenizer.special_tokens[token] + + @property + def unk_token(self) -> str: + return "" + + @unk_token.setter + def unk_token(self, value: str): + self._unk_token = value + + @property + def pad_token(self) -> str: + return "" + + @pad_token.setter + def pad_token(self, value: str): + self._pad_token = value + + @property + def pad_token_id(self): + return self.get_command("") + + @property + def eos_token(self) -> str: + return "" + + @eos_token.setter + def eos_token(self, value: str): + self._eos_token = value + + @property + def eos_token_id(self): + return self.get_command("") + + @property + def vocab_size(self): + return self.tokenizer.n_words + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text, **kwargs): + return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.tokenizer.convert_token_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.tokenizer.convert_id_to_token(index) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return self.tokenizer.decode_tokens(tokens) + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if os.path.isdir(save_directory): + vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"]) + else: + vocab_file = save_directory + + with open(self.vocab_file, "rb") as fin: + proto_str = fin.read() + + with open(vocab_file, "wb") as writer: + writer.write(proto_str) + + return (vocab_file,) + + def get_prefix_tokens(self): + prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")] + return prefix_tokens + + def build_single_message(self, role, metadata, message): + assert role in ["system", "user", "assistant", "observation"], role + role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n") + message_tokens = self.tokenizer.encode(message) + tokens = role_tokens + message_tokens + return tokens + + def build_chat_input(self, query, history=None, role="user"): + if history is None: + history = [] + input_ids = [] + for item in history: + content = item["content"] + if item["role"] == "system" and "tools" in item: + content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False) + input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content)) + input_ids.extend(self.build_single_message(role, "", query)) + input_ids.extend([self.get_command("<|assistant|>")]) + return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + prefix_tokens = self.get_prefix_tokens() + token_ids_0 = prefix_tokens + token_ids_0 + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("")] + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + assert self.padding_side == "left" + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * seq_length + + if "position_ids" not in encoded_inputs: + encoded_inputs["position_ids"] = list(range(seq_length)) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "position_ids" in encoded_inputs: + encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs diff --git a/diffusers/src/diffusers/pipelines/latent_consistency_models/__init__.py b/diffusers/src/diffusers/pipelines/latent_consistency_models/__init__.py new file mode 100755 index 0000000..8f79d3c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/latent_consistency_models/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_latent_consistency_img2img"] = ["LatentConsistencyModelImg2ImgPipeline"] + _import_structure["pipeline_latent_consistency_text2img"] = ["LatentConsistencyModelPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline + from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/diffusers/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py new file mode 100755 index 0000000..dd72d3c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -0,0 +1,976 @@ +# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import LCMScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import AutoPipelineForImage2Image + >>> import torch + >>> import PIL + + >>> pipe = AutoPipelineForImage2Image.from_pretrained("SimianLuo/LCM_Dreamshaper_v7") + >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality. + >>> pipe.to(torch_device="cuda", torch_dtype=torch.float32) + + >>> prompt = "High altitude snowy mountains" + >>> image = PIL.Image.open("./snowy_mountains.png") + + >>> # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps. + >>> num_inference_steps = 4 + >>> images = pipe( + ... prompt=prompt, image=image, num_inference_steps=num_inference_steps, guidance_scale=8.0 + ... ).images + + >>> images[0].save("image.png") + ``` + +""" + + +class LatentConsistencyModelImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for image-to-image generation using a latent consistency model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only + supports [`LCMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + requires_safety_checker (`bool`, *optional*, defaults to `True`): + Whether the pipeline requires a safety checker component. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: LCMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt: Union[str, List[str]], + strength: float, + callback_steps: int, + prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def clip_skip(self): + return self._clip_skip + + @property + def do_classifier_free_guidance(self): + return False + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + num_inference_steps: int = 4, + strength: float = 0.8, + original_inference_steps: int = None, + timesteps: List[int] = None, + guidance_scale: float = 8.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + original_inference_steps (`int`, *optional*): + The original number of inference steps use to generate a linearly-spaced timestep schedule, from which + we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule, + following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the + scheduler's `original_inference_steps` attribute. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending + order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + Note that the original latent consistency models paper uses a different CFG formulation where the + guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when `guidance_scale > + 0`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + callback_steps, + prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + # NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided + # distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the + # unconditional prompt "" (the empty string). Due to this, LCMs currently do not support negative prompts. + prompt_embeds, _ = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Encode image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + original_inference_steps=original_inference_steps, + strength=strength, + ) + + # 6. Prepare latent variables + original_inference_steps = ( + original_inference_steps + if original_inference_steps is not None + else self.scheduler.config.original_inference_steps + ) + latent_timestep = timesteps[:1] + if latents is None: + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + bs = batch_size * num_images_per_prompt + + # 6. Get Guidance Scale Embedding + # NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper + # CFG formulation, so we need to subtract 1 from the input guidance_scale. + # LCM CFG formulation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG) + w = torch.tensor(self.guidance_scale - 1).repeat(bs) + w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).to( + device=device, dtype=latents.dtype + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 8. LCM Multistep Sampling Loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latents = latents.to(prompt_embeds.dtype) + + # model prediction (v-prediction, eps, x) + model_pred = self.unet( + latents, + t, + timestep_cond=w_embedding, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False) + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + w_embedding = callback_outputs.pop("w_embedding", w_embedding) + denoised = callback_outputs.pop("denoised", denoised) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + denoised = denoised.to(prompt_embeds.dtype) + if not output_type == "latent": + image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = denoised + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/diffusers/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py new file mode 100755 index 0000000..89cafc2 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -0,0 +1,905 @@ +# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import LCMScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import DiffusionPipeline + >>> import torch + + >>> pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7") + >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality. + >>> pipe.to(torch_device="cuda", torch_dtype=torch.float32) + + >>> prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" + + >>> # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps. + >>> num_inference_steps = 4 + >>> images = pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=8.0).images + >>> images[0].save("image.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LatentConsistencyModelPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using a latent consistency model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only + supports [`LCMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + requires_safety_checker (`bool`, *optional*, defaults to `True`): + Whether the pipeline requires a safety checker component. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: LCMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Currently StableDiffusionPipeline.check_inputs with negative prompt stuff removed + def check_inputs( + self, + prompt: Union[str, List[str]], + height: int, + width: int, + callback_steps: int, + prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def clip_skip(self): + return self._clip_skip + + @property + def do_classifier_free_guidance(self): + return False + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 4, + original_inference_steps: int = None, + timesteps: List[int] = None, + guidance_scale: float = 8.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + original_inference_steps (`int`, *optional*): + The original number of inference steps use to generate a linearly-spaced timestep schedule, from which + we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule, + following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the + scheduler's `original_inference_steps` attribute. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending + order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + Note that the original latent consistency models paper uses a different CFG formulation where the + guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when `guidance_scale > + 0`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + # NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided + # distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the + # unconditional prompt "" (the empty string). Due to this, LCMs currently do not support negative prompts. + prompt_embeds, _ = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, original_inference_steps=original_inference_steps + ) + + # 5. Prepare latent variable + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + bs = batch_size * num_images_per_prompt + + # 6. Get Guidance Scale Embedding + # NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper + # CFG formulation, so we need to subtract 1 from the input guidance_scale. + # LCM CFG formulation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG) + w = torch.tensor(self.guidance_scale - 1).repeat(bs) + w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).to( + device=device, dtype=latents.dtype + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 8. LCM MultiStep Sampling Loop: + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latents = latents.to(prompt_embeds.dtype) + + # model prediction (v-prediction, eps, x) + model_pred = self.unet( + latents, + t, + timestep_cond=w_embedding, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False) + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + w_embedding = callback_outputs.pop("w_embedding", w_embedding) + denoised = callback_outputs.pop("denoised", denoised) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + denoised = denoised.to(prompt_embeds.dtype) + if not output_type == "latent": + image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = denoised + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/latent_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/latent_diffusion/__init__.py new file mode 100755 index 0000000..561f96f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/latent_diffusion/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_latent_diffusion"] = ["LDMBertModel", "LDMTextToImagePipeline"] + _import_structure["pipeline_latent_diffusion_superresolution"] = ["LDMSuperResolutionPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline + from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/diffusers/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py new file mode 100755 index 0000000..f6f3531 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -0,0 +1,746 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput +from transformers.utils import logging + +from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class LDMTextToImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using latent diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + def __init__( + self, + vqvae: Union[VQModel, AutoencoderKL], + bert: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + unet: Union[UNet2DModel, UNet2DConditionModel], + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 1.0, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 1.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import DiffusionPipeline + + >>> # load model and scheduler + >>> ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> prompt = "A painting of a squirrel eating a burger" + >>> images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images + + >>> # save images + >>> for idx, image in enumerate(images): + ... image.save(f"squirrel-{idx}.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get unconditional embeddings for classifier free guidance + if guidance_scale != 1.0: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt" + ) + negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self._execution_device))[0] + + # get prompt text embeddings + text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt") + prompt_embeds = self.bert(text_input.input_ids.to(self._execution_device))[0] + + # get the initial random noise unless the user supplied it + latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + latents_shape, generator=generator, device=self._execution_device, dtype=prompt_embeds.dtype + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self._execution_device) + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale == 1.0: + # guidance_scale of 1 means no guidance + latents_input = latents + context = prompt_embeds + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = torch.cat([latents] * 2) + context = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # predict the noise residual + noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample + # perform guidance + if guidance_scale != 1.0: + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / self.vqvae.config.scaling_factor * latents + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + +################################################################################ +# Code for the text transformer model +################################################################################ +""" PyTorch LDMBERT model.""" + + +logger = logging.get_logger(__name__) + +LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "ldm-bert", + # See all LDMBert models at https://huggingface.co/models?filter=ldmbert +] + + +LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "ldm-bert": "https://huggingface.co/valhalla/ldm-bert/blob/main/config.json", +} + + +""" LDMBERT model configuration""" + + +class LDMBertConfig(PretrainedConfig): + model_type = "ldmbert" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=30522, + max_position_embeddings=77, + encoder_layers=32, + encoder_ffn_dim=5120, + encoder_attention_heads=8, + head_dim=64, + encoder_layerdrop=0.0, + activation_function="gelu", + d_model=1280, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + use_cache=True, + pad_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.head_dim = head_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert +class LDMBertAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.out_proj = nn.Linear(self.inner_dim, embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class LDMBertEncoderLayer(nn.Module): + def __init__(self, config: LDMBertConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = LDMBertAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + head_dim=config.head_dim, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + hidden_states (`torch.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert +class LDMBertPreTrainedModel(PreTrainedModel): + config_class = LDMBertConfig + base_model_prefix = "model" + _supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LDMBertEncoder,)): + module.gradient_checkpointing = value + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class LDMBertEncoder(LDMBertPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`LDMBertEncoderLayer`]. + + Args: + config: LDMBertConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + + self.dropout = config.dropout + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) + self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim) + self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + seq_len = input_shape[1] + if position_ids is None: + position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) + embed_pos = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class LDMBertModel(LDMBertPreTrainedModel): + _no_split_modules = [] + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + self.model = LDMBertEncoder(config) + self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return outputs diff --git a/diffusers/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/diffusers/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py new file mode 100755 index 0000000..bb72b4d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py @@ -0,0 +1,189 @@ +import inspect +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.utils.checkpoint + +from ...models import UNet2DModel, VQModel +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import PIL_INTERPOLATION +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +def preprocess(image): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +class LDMSuperResolutionPipeline(DiffusionPipeline): + r""" + A pipeline for image super-resolution using latent diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], + [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vqvae: VQModel, + unet: UNet2DModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + image: Union[torch.Tensor, PIL.Image.Image] = None, + batch_size: Optional[int] = 1, + num_inference_steps: Optional[int] = 100, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + The call function to the pipeline for generation. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + `Image` or tensor representing an image batch to be used as the starting point for the process. + batch_size (`int`, *optional*, defaults to 1): + Number of images to generate. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> import requests + >>> from PIL import Image + >>> from io import BytesIO + >>> from diffusers import LDMSuperResolutionPipeline + >>> import torch + + >>> # load model and scheduler + >>> pipeline = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages") + >>> pipeline = pipeline.to("cuda") + + >>> # let's download an image + >>> url = ( + ... "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png" + ... ) + >>> response = requests.get(url) + >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> low_res_img = low_res_img.resize((128, 128)) + + >>> # run pipeline in inference (sample random noise and denoise) + >>> upscaled_image = pipeline(low_res_img, num_inference_steps=100, eta=1).images[0] + >>> # save image + >>> upscaled_image.save("ldm_generated_image.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + else: + raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") + + if isinstance(image, PIL.Image.Image): + image = preprocess(image) + + height, width = image.shape[-2:] + + # in_channels should be 6: 3 for latents, 3 for low resolution image + latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width) + latents_dtype = next(self.unet.parameters()).dtype + + latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + + image = image.to(device=self.device, dtype=latents_dtype) + + # set timesteps and move to the correct device + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps_tensor = self.scheduler.timesteps + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature. + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(timesteps_tensor): + # concat latents and low resolution image in the channel dimension. + latents_input = torch.cat([latents, image], dim=1) + latents_input = self.scheduler.scale_model_input(latents_input, t) + # predict the noise residual + noise_pred = self.unet(latents_input, t).sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + # decode the image latents with the VQVAE + image = self.vqvae.decode(latents).sample + image = torch.clamp(image, -1.0, 1.0) + image = image / 2 + 0.5 + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/latte/__init__.py b/diffusers/src/diffusers/pipelines/latte/__init__.py new file mode 100755 index 0000000..4296b42 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/latte/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_latte"] = ["LattePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_latte import LattePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/latte/pipeline_latte.py b/diffusers/src/diffusers/pipelines/latte/pipeline_latte.py new file mode 100755 index 0000000..d4bedf2 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/latte/pipeline_latte.py @@ -0,0 +1,881 @@ +# Copyright 2024 the Latte Team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKL, LatteTransformer3DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + BaseOutput, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LattePipeline + >>> from diffusers.utils import export_to_gif + + >>> # You can replace the checkpoint id with "maxin-cn/Latte-1" too. + >>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> videos = pipe(prompt).frames[0] + >>> export_to_gif(videos, "latte.gif") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class LattePipelineOutput(BaseOutput): + frames: torch.Tensor + + +class LattePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Latte. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Latte uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`LatteTransformer3DModel`]): + A text conditioned `LatteTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + bad_punct_regex = re.compile(r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\\*]{1,}") + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: LatteTransformer3DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + + # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py + def mask_text_embeddings(self, emb, mask): + if emb.shape[0] == 1: + keep_index = mask.sum().item() + return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096 + else: + masked_feature = emb * mask[:, None, :, None] # 1 120 4096 + return masked_feature, emb.shape[2] + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + mask_feature: bool = True, + dtype=None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Latte, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of video that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For Latte, it's should be the embeddings of the "" string. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + mask_feature: (bool, defaults to `True`): + If `True`, the function will mask the text embeddings. + """ + embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = 120 + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds_attention_mask = attention_mask + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds_attention_mask = torch.ones_like(prompt_embeds) + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1) + prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + # Perform additional masking. + if mask_feature and not embeds_initially_provided: + prompt_embeds = prompt_embeds.unsqueeze(1) + masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask) + masked_prompt_embeds = masked_prompt_embeds.squeeze(1) + masked_negative_prompt_embeds = ( + negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None + ) + + return masked_prompt_embeds, masked_negative_prompt_embeds + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 7.5, + num_images_per_prompt: int = 1, + video_length: int = 16, + height: int = 512, + width: int = 512, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clean_caption: bool = True, + mask_feature: bool = True, + enable_temporal_attentions: bool = True, + decode_chunk_size: Optional[int] = None, + ) -> Union[LattePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. + video_length (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For Latte this negative prompt should be "". If not provided, + negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate video. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked. + enable_temporal_attentions (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the + expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality. + For lower memory usage, reduce `decode_chunk_size`. + + Examples: + + Returns: + [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else video_length + + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + mask_feature=mask_feature, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + video_length, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=current_timestep, + enable_temporal_attentions=enable_temporal_attentions, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # use learned sigma? + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous video: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latents": + video = self.decode_latents(latents, video_length, decode_chunk_size=14) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LattePipelineOutput(frames=video) + + # Similar to diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion.decode_latents + def decode_latents(self, latents: torch.Tensor, video_length: int, decode_chunk_size: int = 14): + # [batch, channels, frames, height, width] -> [batch*frames, channels, height, width] + latents = latents.permute(0, 2, 1, 3, 4).flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward + accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, video_length, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.float() + return frames diff --git a/diffusers/src/diffusers/pipelines/ledits_pp/__init__.py b/diffusers/src/diffusers/pipelines/ledits_pp/__init__.py new file mode 100755 index 0000000..aae3b1c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/ledits_pp/__init__.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_leditspp_stable_diffusion"] = ["LEditsPPPipelineStableDiffusion"] + _import_structure["pipeline_leditspp_stable_diffusion_xl"] = ["LEditsPPPipelineStableDiffusionXL"] + + _import_structure["pipeline_output"] = ["LEditsPPDiffusionPipelineOutput", "LEditsPPDiffusionPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_leditspp_stable_diffusion import ( + LEditsPPDiffusionPipelineOutput, + LEditsPPInversionPipelineOutput, + LEditsPPPipelineStableDiffusion, + ) + from .pipeline_leditspp_stable_diffusion_xl import LEditsPPPipelineStableDiffusionXL + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/diffusers/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py new file mode 100755 index 0000000..049b896 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py @@ -0,0 +1,1499 @@ +import inspect +import math +from itertools import repeat +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import Attention, AttnProcessor +from ...models.lora import adjust_lora_scale_text_encoder +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import LEditsPPPipelineStableDiffusion + >>> from diffusers.utils import load_image + + >>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png" + >>> image = load_image(img_url).convert("RGB") + + >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1) + + >>> edited_image = pipe( + ... editing_prompt=["cherry blossom"], edit_guidance_scale=10.0, edit_threshold=0.75 + ... ).images[0] + ``` +""" + + +# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionAttendAndExcitePipeline.AttentionStore +class LeditsAttentionStore: + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []} + + def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False): + # attn.shape = batch_size * head_size, seq_len query, seq_len_key + if attn.shape[1] <= self.max_size: + bs = 1 + int(PnP) + editing_prompts + skip = 2 if PnP else 1 # skip PnP & unconditional + attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3) + source_batch_size = int(attn.shape[1] // bs) + self.forward(attn[:, skip * source_batch_size :], is_cross, place_in_unet) + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + + self.step_store[key].append(attn) + + def between_steps(self, store_step=True): + if store_step: + if self.average: + if len(self.attention_store) == 0: + self.attention_store = self.step_store + else: + for key in self.attention_store: + for i in range(len(self.attention_store[key])): + self.attention_store[key][i] += self.step_store[key][i] + else: + if len(self.attention_store) == 0: + self.attention_store = [self.step_store] + else: + self.attention_store.append(self.step_store) + + self.cur_step += 1 + self.step_store = self.get_empty_store() + + def get_attention(self, step: int): + if self.average: + attention = { + key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store + } + else: + assert step is not None + attention = self.attention_store[step] + return attention + + def aggregate_attention( + self, attention_maps, prompts, res: Union[int, Tuple[int]], from_where: List[str], is_cross: bool, select: int + ): + out = [[] for x in range(self.batch_size)] + if isinstance(res, int): + num_pixels = res**2 + resolution = (res, res) + else: + num_pixels = res[0] * res[1] + resolution = res[:2] + + for location in from_where: + for bs_item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + for batch, item in enumerate(bs_item): + if item.shape[1] == num_pixels: + cross_maps = item.reshape(len(prompts), -1, *resolution, item.shape[-1])[select] + out[batch].append(cross_maps) + + out = torch.stack([torch.cat(x, dim=0) for x in out]) + # average over heads + out = out.sum(1) / out.shape[1] + return out + + def __init__(self, average: bool, batch_size=1, max_resolution=16, max_size: int = None): + self.step_store = self.get_empty_store() + self.attention_store = [] + self.cur_step = 0 + self.average = average + self.batch_size = batch_size + if max_size is None: + self.max_size = max_resolution**2 + elif max_size is not None and max_resolution is None: + self.max_size = max_size + else: + raise ValueError("Only allowed to set one of max_resolution or max_size") + + +# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionAttendAndExcitePipeline.GaussianSmoothing +class LeditsGaussianSmoothing: + def __init__(self, device): + kernel_size = [3, 3] + sigma = [0.5, 0.5] + + # The gaussian kernel is the product of the gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(1, *[1] * (kernel.dim() - 1)) + + self.weight = kernel.to(device) + + def __call__(self, input): + """ + Arguments: + Apply gaussian filter to input. + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return F.conv2d(input, weight=self.weight.to(input.dtype)) + + +class LEDITSCrossAttnProcessor: + def __init__(self, attention_store, place_in_unet, pnp, editing_prompts): + self.attnstore = attention_store + self.place_in_unet = place_in_unet + self.editing_prompts = editing_prompts + self.pnp = pnp + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states, + attention_mask=None, + temb=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + self.attnstore( + attention_probs, + is_cross=True, + place_in_unet=self.place_in_unet, + editing_prompts=self.editing_prompts, + PnP=self.pnp, + ) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LEditsPPPipelineStableDiffusion( + DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): + """ + Pipeline for textual image editing using LEDits++ with Stable Diffusion. + + This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionPipeline`]. Check the superclass + documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular + device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer ([`~transformers.CLIPTokenizer`]): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed it will + automatically be set to [`DPMSolverMultistepScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, DPMSolverMultistepScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if not isinstance(scheduler, DDIMScheduler) and not isinstance(scheduler, DPMSolverMultistepScheduler): + scheduler = DPMSolverMultistepScheduler.from_config( + scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2 + ) + logger.warning( + "This pipeline only supports DDIMScheduler and DPMSolverMultistepScheduler. " + "The scheduler has been changed to DPMSolverMultistepScheduler." + ) + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.inversion_steps = None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, eta, generator=None): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + negative_prompt=None, + editing_prompt_embeddings=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if editing_prompt_embeddings is not None and negative_prompt_embeds is not None: + if editing_prompt_embeddings.shape != negative_prompt_embeds.shape: + raise ValueError( + "`editing_prompt_embeddings` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `editing_prompt_embeddings` {editing_prompt_embeddings.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents): + # shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + + # if latents.shape != shape: + # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_unet(self, attention_store, PnP: bool = False): + attn_procs = {} + for name in self.unet.attn_processors.keys(): + if name.startswith("mid_block"): + place_in_unet = "mid" + elif name.startswith("up_blocks"): + place_in_unet = "up" + elif name.startswith("down_blocks"): + place_in_unet = "down" + else: + continue + + if "attn2" in name and place_in_unet != "mid": + attn_procs[name] = LEDITSCrossAttnProcessor( + attention_store=attention_store, + place_in_unet=place_in_unet, + pnp=PnP, + editing_prompts=self.enabled_editing_prompts, + ) + else: + attn_procs[name] = AttnProcessor() + + self.unet.set_attn_processor(attn_procs) + + def encode_prompt( + self, + device, + num_images_per_prompt, + enable_edit_guidance, + negative_prompt=None, + editing_prompt=None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + editing_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + enable_edit_guidance (`bool`): + whether to perform any editing or reconstruct the input image instead + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + editing_prompt (`str` or `List[str]`, *optional*): + Editing prompt(s) to be encoded. If not defined, one has to pass `editing_prompt_embeds` instead. + editing_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + batch_size = self.batch_size + num_edit_tokens = None + + if negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but exoected" + f"{batch_size} based on the input images. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = negative_prompt_embeds.dtype + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if enable_edit_guidance: + if editing_prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + # if isinstance(self, TextualInversionLoaderMixin): + # prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + + max_length = negative_prompt_embeds.shape[1] + text_inputs = self.tokenizer( + [x for item in editing_prompt for x in repeat(item, batch_size)], + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + return_length=True, + ) + + num_edit_tokens = text_inputs.length - 2 # not counting startoftext and endoftext + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + [x for item in editing_prompt for x in repeat(item, batch_size)], + padding="longest", + return_tensors="pt", + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + editing_prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + editing_prompt_embeds = editing_prompt_embeds[0] + else: + editing_prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + editing_prompt_embeds = editing_prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + editing_prompt_embeds = self.text_encoder.text_model.final_layer_norm(editing_prompt_embeds) + + editing_prompt_embeds = editing_prompt_embeds.to(dtype=negative_prompt_embeds.dtype, device=device) + + bs_embed_edit, seq_len, _ = editing_prompt_embeds.shape + editing_prompt_embeds = editing_prompt_embeds.to(dtype=negative_prompt_embeds.dtype, device=device) + editing_prompt_embeds = editing_prompt_embeds.repeat(1, num_images_per_prompt, 1) + editing_prompt_embeds = editing_prompt_embeds.view(bs_embed_edit * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return editing_prompt_embeds, negative_prompt_embeds, num_edit_tokens + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + editing_prompt: Optional[Union[str, List[str]]] = None, + editing_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, + edit_guidance_scale: Optional[Union[float, List[float]]] = 5, + edit_warmup_steps: Optional[Union[int, List[int]]] = 0, + edit_cooldown_steps: Optional[Union[int, List[int]]] = None, + edit_threshold: Optional[Union[float, List[float]]] = 0.9, + user_mask: Optional[torch.Tensor] = None, + sem_guidance: Optional[List[torch.Tensor]] = None, + use_cross_attn_mask: bool = False, + use_intersect_mask: bool = True, + attn_store_steps: Optional[List[int]] = [], + store_averaged_over_steps: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for editing. The + [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusion.invert`] method has to be called beforehand. Edits will + always be performed for the last inverted image(s). + + Args: + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] instead of a plain + tuple. + editing_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. The image is reconstructed by setting + `editing_prompt = None`. Guidance direction of prompt should be specified via + `reverse_editing_direction`. + editing_prompt_embeds (`torch.Tensor>`, *optional*): + Pre-computed embeddings to use for guiding the image generation. Guidance direction of embedding should + be specified via `reverse_editing_direction`. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): + Whether the corresponding prompt in `editing_prompt` should be increased or decreased. + edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): + Guidance scale for guiding the image generation. If provided as list values should correspond to + `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 12 of [LEDITS++ + Paper](https://arxiv.org/abs/2301.12247). + edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): + Number of diffusion steps (for each prompt) for which guidance will not be applied. + edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): + Number of diffusion steps (for each prompt) after which guidance will no longer be applied. + edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): + Masking threshold of guidance. Threshold should be proportional to the image region that is modified. + 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++ + Paper](https://arxiv.org/abs/2301.12247). + user_mask (`torch.Tensor`, *optional*): + User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s + implicit masks do not meet user preferences. + sem_guidance (`List[torch.Tensor]`, *optional*): + List of pre-generated guidance vectors to be applied at generation. Length of the list has to + correspond to `num_inference_steps`. + use_cross_attn_mask (`bool`, defaults to `False`): + Whether cross-attention masks are used. Cross-attention masks are always used when use_intersect_mask + is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of [LEDITS++ + paper](https://arxiv.org/pdf/2311.16711.pdf). + use_intersect_mask (`bool`, defaults to `True`): + Whether the masking term is calculated as intersection of cross-attention masks and masks derived from + the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise estimate + are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf). + attn_store_steps (`List[int]`, *optional*): + Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes. + store_averaged_over_steps (`bool`, defaults to `True`): + Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps. If + False, attention maps for each step are stores separately. Just for visualization purposes. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + content, according to the `safety_checker`. + """ + + if self.inversion_steps is None: + raise ValueError( + "You need to invert an input image first before calling the pipeline. The `invert` method has to be called beforehand. Edits will always be performed for the last inverted image(s)." + ) + + eta = self.eta + num_images_per_prompt = 1 + latents = self.init_latents + + zs = self.zs + self.scheduler.set_timesteps(len(self.scheduler.timesteps)) + + if use_intersect_mask: + use_cross_attn_mask = True + + if use_cross_attn_mask: + self.smoothing = LeditsGaussianSmoothing(self.device) + + if user_mask is not None: + user_mask = user_mask.to(self.device) + + org_prompt = "" + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + negative_prompt, + editing_prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + batch_size = self.batch_size + + if editing_prompt: + enable_edit_guidance = True + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + self.enabled_editing_prompts = len(editing_prompt) + elif editing_prompt_embeds is not None: + enable_edit_guidance = True + self.enabled_editing_prompts = editing_prompt_embeds.shape[0] + else: + self.enabled_editing_prompts = 0 + enable_edit_guidance = False + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + edit_concepts, uncond_embeddings, num_edit_tokens = self.encode_prompt( + editing_prompt=editing_prompt, + device=self.device, + num_images_per_prompt=num_images_per_prompt, + enable_edit_guidance=enable_edit_guidance, + negative_prompt=negative_prompt, + editing_prompt_embeds=editing_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if enable_edit_guidance: + text_embeddings = torch.cat([uncond_embeddings, edit_concepts]) + self.text_cross_attention_maps = [editing_prompt] if isinstance(editing_prompt, str) else editing_prompt + else: + text_embeddings = torch.cat([uncond_embeddings]) + + # 4. Prepare timesteps + # self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = self.inversion_steps + t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0] :])} + + if use_cross_attn_mask: + self.attention_store = LeditsAttentionStore( + average=store_averaged_over_steps, + batch_size=batch_size, + max_size=(latents.shape[-2] / 4.0) * (latents.shape[-1] / 4.0), + max_resolution=None, + ) + self.prepare_unet(self.attention_store, PnP=False) + resolution = latents.shape[-2:] + att_res = (int(resolution[0] / 4), int(resolution[1] / 4)) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + None, + None, + text_embeddings.dtype, + self.device, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(eta) + + self.sem_guidance = None + self.activation_mask = None + + # 7. Denoising loop + num_warmup_steps = 0 + with self.progress_bar(total=len(timesteps)) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + + if enable_edit_guidance: + latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts)) + else: + latent_model_input = latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + text_embed_input = text_embeddings + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input).sample + + noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts) # [b,4, 64, 64] + noise_pred_uncond = noise_pred_out[0] + noise_pred_edit_concepts = noise_pred_out[1:] + + noise_guidance_edit = torch.zeros( + noise_pred_uncond.shape, + device=self.device, + dtype=noise_pred_uncond.dtype, + ) + + if sem_guidance is not None and len(sem_guidance) > i: + noise_guidance_edit += sem_guidance[i].to(self.device) + + elif enable_edit_guidance: + if self.activation_mask is None: + self.activation_mask = torch.zeros( + (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) + ) + + if self.sem_guidance is None: + self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape)) + + for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): + if isinstance(edit_warmup_steps, list): + edit_warmup_steps_c = edit_warmup_steps[c] + else: + edit_warmup_steps_c = edit_warmup_steps + if i < edit_warmup_steps_c: + continue + + if isinstance(edit_guidance_scale, list): + edit_guidance_scale_c = edit_guidance_scale[c] + else: + edit_guidance_scale_c = edit_guidance_scale + + if isinstance(edit_threshold, list): + edit_threshold_c = edit_threshold[c] + else: + edit_threshold_c = edit_threshold + if isinstance(reverse_editing_direction, list): + reverse_editing_direction_c = reverse_editing_direction[c] + else: + reverse_editing_direction_c = reverse_editing_direction + + if isinstance(edit_cooldown_steps, list): + edit_cooldown_steps_c = edit_cooldown_steps[c] + elif edit_cooldown_steps is None: + edit_cooldown_steps_c = i + 1 + else: + edit_cooldown_steps_c = edit_cooldown_steps + + if i >= edit_cooldown_steps_c: + continue + + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond + + if reverse_editing_direction_c: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c + + if user_mask is not None: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask + + if use_cross_attn_mask: + out = self.attention_store.aggregate_attention( + attention_maps=self.attention_store.step_store, + prompts=self.text_cross_attention_maps, + res=att_res, + from_where=["up", "down"], + is_cross=True, + select=self.text_cross_attention_maps.index(editing_prompt[c]), + ) + attn_map = out[:, :, :, 1 : 1 + num_edit_tokens[c]] # 0 -> startoftext + + # average over all tokens + if attn_map.shape[3] != num_edit_tokens[c]: + raise ValueError( + f"Incorrect shape of attention_map. Expected size {num_edit_tokens[c]}, but found {attn_map.shape[3]}!" + ) + + attn_map = torch.sum(attn_map, dim=3) + + # gaussian_smoothing + attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect") + attn_map = self.smoothing(attn_map).squeeze(1) + + # torch.quantile function expects float32 + if attn_map.dtype == torch.float32: + tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1) + else: + tmp = torch.quantile( + attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1 + ).to(attn_map.dtype) + attn_mask = torch.where( + attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1, *att_res), 1.0, 0.0 + ) + + # resolution must match latent space dimension + attn_mask = F.interpolate( + attn_mask.unsqueeze(1), + noise_guidance_edit_tmp.shape[-2:], # 64,64 + ).repeat(1, 4, 1, 1) + self.activation_mask[i, c] = attn_mask.detach().cpu() + if not use_intersect_mask: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask + + if use_intersect_mask: + if t <= 800: + noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) + noise_guidance_edit_tmp_quantile = torch.sum( + noise_guidance_edit_tmp_quantile, dim=1, keepdim=True + ) + noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat( + 1, self.unet.config.in_channels, 1, 1 + ) + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp_quantile.dtype == torch.float32: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp_quantile.dtype) + + intersect_mask = ( + torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + torch.ones_like(noise_guidance_edit_tmp), + torch.zeros_like(noise_guidance_edit_tmp), + ) + * attn_mask + ) + + self.activation_mask[i, c] = intersect_mask.detach().cpu() + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask + + else: + # print(f"only attention mask for step {i}") + noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask + + elif not use_cross_attn_mask: + # calculate quantile + noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) + noise_guidance_edit_tmp_quantile = torch.sum( + noise_guidance_edit_tmp_quantile, dim=1, keepdim=True + ) + noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1) + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp_quantile.dtype == torch.float32: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp_quantile.dtype) + + self.activation_mask[i, c] = ( + torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + torch.ones_like(noise_guidance_edit_tmp), + torch.zeros_like(noise_guidance_edit_tmp), + ) + .detach() + .cpu() + ) + + noise_guidance_edit_tmp = torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + noise_guidance_edit_tmp, + torch.zeros_like(noise_guidance_edit_tmp), + ) + + noise_guidance_edit += noise_guidance_edit_tmp + + self.sem_guidance[i] = noise_guidance_edit.detach().cpu() + + noise_pred = noise_pred_uncond + noise_guidance_edit + + if enable_edit_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_edit_concepts.mean(dim=0, keepdim=False), + guidance_rescale=self.guidance_rescale, + ) + + idx = t_to_idx[int(t)] + latents = self.scheduler.step( + noise_pred, t, latents, variance_noise=zs[idx], **extra_step_kwargs + ).prev_sample + + # step callback + if use_cross_attn_mask: + store_step = i in attn_store_steps + self.attention_store.between_steps(store_step) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + # prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 8. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return LEditsPPDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + @torch.no_grad() + def invert( + self, + image: PipelineImageInput, + source_prompt: str = "", + source_guidance_scale: float = 3.5, + num_inversion_steps: int = 30, + skip: float = 0.15, + generator: Optional[torch.Generator] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + resize_mode: Optional[str] = "default", + crops_coords: Optional[Tuple[int, int, int, int]] = None, + ): + r""" + The function to the pipeline for image inversion as described by the [LEDITS++ + Paper](https://arxiv.org/abs/2301.12247). If the scheduler is set to [`~schedulers.DDIMScheduler`] the + inversion proposed by [edit-friendly DPDM](https://arxiv.org/abs/2304.06140) will be performed instead. + + Args: + image (`PipelineImageInput`): + Input for the image(s) that are to be edited. Multiple input images have to default to the same aspect + ratio. + source_prompt (`str`, defaults to `""`): + Prompt describing the input image that will be used for guidance during inversion. Guidance is disabled + if the `source_prompt` is `""`. + source_guidance_scale (`float`, defaults to `3.5`): + Strength of guidance during inversion. + num_inversion_steps (`int`, defaults to `30`): + Number of total performed inversion steps after discarding the initial `skip` steps. + skip (`float`, defaults to `0.15`): + Portion of initial steps that will be ignored for inversion and subsequent generation. Lower values + will lead to stronger changes to the input image. `skip` has to be between `0` and `1`. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make inversion + deterministic. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default + height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. + resize_mode (`str`, *optional*, defaults to `default`): + The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within + the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will + resize the image to fit within the specified width and height, maintaining the aspect ratio, and then + center the image within the dimensions, filling empty with data from image. If `crop`, will resize the + image to fit within the specified width and height, maintaining the aspect ratio, and then center the + image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only + supported for PIL image input. + crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): + The crop coordinates for each image in the batch. If `None`, will not crop the image. + + Returns: + [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s) + and respective VAE reconstruction(s). + """ + # Reset attn processor, we do not want to store attn maps during inversion + self.unet.set_attn_processor(AttnProcessor()) + + self.eta = 1.0 + + self.scheduler.config.timestep_spacing = "leading" + self.scheduler.set_timesteps(int(num_inversion_steps * (1 + skip))) + self.inversion_steps = self.scheduler.timesteps[-num_inversion_steps:] + timesteps = self.inversion_steps + + # 1. encode image + x0, resized = self.encode_image( + image, + dtype=self.text_encoder.dtype, + height=height, + width=width, + resize_mode=resize_mode, + crops_coords=crops_coords, + ) + self.batch_size = x0.shape[0] + + # autoencoder reconstruction + image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0] + image_rec = self.image_processor.postprocess(image_rec, output_type="pil") + + # 2. get embeddings + do_classifier_free_guidance = source_guidance_scale > 1.0 + + lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + + uncond_embedding, text_embeddings, _ = self.encode_prompt( + num_images_per_prompt=1, + device=self.device, + negative_prompt=None, + enable_edit_guidance=do_classifier_free_guidance, + editing_prompt=source_prompt, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + # 3. find zs and xts + variance_noise_shape = (num_inversion_steps, *x0.shape) + + # intermediate latents + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype) + + for t in reversed(timesteps): + idx = num_inversion_steps - t_to_idx[int(t)] - 1 + noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype) + xts[idx] = self.scheduler.add_noise(x0, noise, torch.Tensor([t])) + xts = torch.cat([x0.unsqueeze(0), xts], dim=0) + + self.scheduler.set_timesteps(len(self.scheduler.timesteps)) + # noise maps + zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype) + + with self.progress_bar(total=len(timesteps)) as progress_bar: + for t in timesteps: + idx = num_inversion_steps - t_to_idx[int(t)] - 1 + # 1. predict noise residual + xt = xts[idx + 1] + + noise_pred = self.unet(xt, timestep=t, encoder_hidden_states=uncond_embedding).sample + + if not source_prompt == "": + noise_pred_cond = self.unet(xt, timestep=t, encoder_hidden_states=text_embeddings).sample + noise_pred = noise_pred + source_guidance_scale * (noise_pred_cond - noise_pred) + + xtm1 = xts[idx] + z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, self.eta) + zs[idx] = z + + # correction to avoid error accumulation + xts[idx] = xtm1_corrected + + progress_bar.update() + + self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1) + zs = zs.flip(0) + self.zs = zs + + return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec) + + @torch.no_grad() + def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None): + image = self.image_processor.preprocess( + image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + resized = self.image_processor.postprocess(image=image, output_type="pil") + + if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5: + logger.warning( + "Your input images far exceed the default resolution of the underlying diffusion model. " + "The output images may contain severe artifacts! " + "Consider down-sampling the input using the `height` and `width` parameters" + ) + image = image.to(dtype) + + x0 = self.vae.encode(image.to(self.device)).latent_dist.mode() + x0 = x0.to(dtype) + x0 = self.vae.config.scaling_factor * x0 + return x0, resized + + +def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if scheduler.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred + + # modifed so that updated xtm1 is returned as well (to avoid error accumulation) + mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + if variance > 0.0: + noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta) + else: + noise = torch.tensor([0.0]).to(latents.device) + + return noise, mu_xt + (eta * variance**0.5) * noise + + +def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta): + def first_order_update(model_output, sample): # timestep, prev_timestep, sample): + sigma_t, sigma_s = scheduler.sigmas[scheduler.step_index + 1], scheduler.sigmas[scheduler.step_index] + alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = scheduler._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + + mu_xt = (sigma_t / sigma_s * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + mu_xt = scheduler.dpm_solver_first_order_update( + model_output=model_output, sample=sample, noise=torch.zeros_like(sample) + ) + + sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) + if sigma > 0.0: + noise = (prev_latents - mu_xt) / sigma + else: + noise = torch.tensor([0.0]).to(sample.device) + + prev_sample = mu_xt + sigma * noise + return noise, prev_sample + + def second_order_update(model_output_list, sample): # timestep_list, prev_timestep, sample): + sigma_t, sigma_s0, sigma_s1 = ( + scheduler.sigmas[scheduler.step_index + 1], + scheduler.sigmas[scheduler.step_index], + scheduler.sigmas[scheduler.step_index - 1], + ) + + alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = scheduler._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = scheduler._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + + mu_xt = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + ) + + sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) + if sigma > 0.0: + noise = (prev_latents - mu_xt) / sigma + else: + noise = torch.tensor([0.0]).to(sample.device) + + prev_sample = mu_xt + sigma * noise + + return noise, prev_sample + + if scheduler.step_index is None: + scheduler._init_step_index(timestep) + + model_output = scheduler.convert_model_output(model_output=noise_pred, sample=latents) + for i in range(scheduler.config.solver_order - 1): + scheduler.model_outputs[i] = scheduler.model_outputs[i + 1] + scheduler.model_outputs[-1] = model_output + + if scheduler.lower_order_nums < 1: + noise, prev_sample = first_order_update(model_output, latents) + else: + noise, prev_sample = second_order_update(scheduler.model_outputs, latents) + + if scheduler.lower_order_nums < scheduler.config.solver_order: + scheduler.lower_order_nums += 1 + + # upon completion increase step index by one + scheduler._step_index += 1 + + return noise, prev_sample + + +def compute_noise(scheduler, *args): + if isinstance(scheduler, DDIMScheduler): + return compute_noise_ddim(scheduler, *args) + elif ( + isinstance(scheduler, DPMSolverMultistepScheduler) + and scheduler.config.algorithm_type == "sde-dpmsolver++" + and scheduler.config.solver_order == 2 + ): + return compute_noise_sde_dpm_pp_2nd(scheduler, *args) + else: + raise NotImplementedError diff --git a/diffusers/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/diffusers/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py new file mode 100755 index 0000000..995cc15 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py @@ -0,0 +1,1794 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + Attention, + AttnProcessor, + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> import PIL + >>> import requests + >>> from io import BytesIO + + >>> from diffusers import LEditsPPPipelineStableDiffusionXL + + >>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg" + >>> image = download_image(img_url) + + >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2) + + >>> edited_image = pipe( + ... editing_prompt=["tennis ball", "tomato"], + ... reverse_editing_direction=[True, False], + ... edit_guidance_scale=[5.0, 10.0], + ... edit_threshold=[0.9, 0.85], + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LeditsAttentionStore +class LeditsAttentionStore: + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []} + + def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False): + # attn.shape = batch_size * head_size, seq_len query, seq_len_key + if attn.shape[1] <= self.max_size: + bs = 1 + int(PnP) + editing_prompts + skip = 2 if PnP else 1 # skip PnP & unconditional + attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3) + source_batch_size = int(attn.shape[1] // bs) + self.forward(attn[:, skip * source_batch_size :], is_cross, place_in_unet) + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + + self.step_store[key].append(attn) + + def between_steps(self, store_step=True): + if store_step: + if self.average: + if len(self.attention_store) == 0: + self.attention_store = self.step_store + else: + for key in self.attention_store: + for i in range(len(self.attention_store[key])): + self.attention_store[key][i] += self.step_store[key][i] + else: + if len(self.attention_store) == 0: + self.attention_store = [self.step_store] + else: + self.attention_store.append(self.step_store) + + self.cur_step += 1 + self.step_store = self.get_empty_store() + + def get_attention(self, step: int): + if self.average: + attention = { + key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store + } + else: + assert step is not None + attention = self.attention_store[step] + return attention + + def aggregate_attention( + self, attention_maps, prompts, res: Union[int, Tuple[int]], from_where: List[str], is_cross: bool, select: int + ): + out = [[] for x in range(self.batch_size)] + if isinstance(res, int): + num_pixels = res**2 + resolution = (res, res) + else: + num_pixels = res[0] * res[1] + resolution = res[:2] + + for location in from_where: + for bs_item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + for batch, item in enumerate(bs_item): + if item.shape[1] == num_pixels: + cross_maps = item.reshape(len(prompts), -1, *resolution, item.shape[-1])[select] + out[batch].append(cross_maps) + + out = torch.stack([torch.cat(x, dim=0) for x in out]) + # average over heads + out = out.sum(1) / out.shape[1] + return out + + def __init__(self, average: bool, batch_size=1, max_resolution=16, max_size: int = None): + self.step_store = self.get_empty_store() + self.attention_store = [] + self.cur_step = 0 + self.average = average + self.batch_size = batch_size + if max_size is None: + self.max_size = max_resolution**2 + elif max_size is not None and max_resolution is None: + self.max_size = max_size + else: + raise ValueError("Only allowed to set one of max_resolution or max_size") + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LeditsGaussianSmoothing +class LeditsGaussianSmoothing: + def __init__(self, device): + kernel_size = [3, 3] + sigma = [0.5, 0.5] + + # The gaussian kernel is the product of the gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(1, *[1] * (kernel.dim() - 1)) + + self.weight = kernel.to(device) + + def __call__(self, input): + """ + Arguments: + Apply gaussian filter to input. + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return F.conv2d(input, weight=self.weight.to(input.dtype)) + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEDITSCrossAttnProcessor +class LEDITSCrossAttnProcessor: + def __init__(self, attention_store, place_in_unet, pnp, editing_prompts): + self.attnstore = attention_store + self.place_in_unet = place_in_unet + self.editing_prompts = editing_prompts + self.pnp = pnp + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states, + attention_mask=None, + temb=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + self.attnstore( + attention_probs, + is_cross=True, + place_in_unet=self.place_in_unet, + editing_prompts=self.editing_prompts, + PnP=self.pnp, + ) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + +class LEditsPPPipelineStableDiffusionXL( + DiffusionPipeline, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, +): + """ + Pipeline for textual image editing using LEDits++ with Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionXLPipeline`]. Check the + superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a + particular device, etc.). + + In addition the pipeline inherits the following loading methods: + - *LoRA*: [`LEditsPPPipelineStableDiffusionXL.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer ([`~transformers.CLIPTokenizer`]): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed it will + automatically be set to [`DPMSolverMultistepScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DPMSolverMultistepScheduler, DDIMScheduler], + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + if not isinstance(scheduler, DDIMScheduler) and not isinstance(scheduler, DPMSolverMultistepScheduler): + self.scheduler = DPMSolverMultistepScheduler.from_config( + scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2 + ) + logger.warning( + "This pipeline only supports DDIMScheduler and DPMSolverMultistepScheduler. " + "The scheduler has been changed to DPMSolverMultistepScheduler." + ) + + self.default_sample_size = self.unet.config.sample_size + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + self.inversion_steps = None + + def encode_prompt( + self, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + enable_edit_guidance: bool = True, + editing_prompt: Optional[str] = None, + editing_prompt_embeds: Optional[torch.Tensor] = None, + editing_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ) -> object: + r""" + Encodes the prompt into text encoder hidden states. + + Args: + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + enable_edit_guidance (`bool`): + Whether to guide towards an editing prompt or not. + editing_prompt (`str` or `List[str]`, *optional*): + Editing prompt(s) to be encoded. If not defined and 'enable_edit_guidance' is True, one has to pass + `editing_prompt_embeds` instead. + editing_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided and 'enable_edit_guidance' is True, editing_prompt_embeds will be generated from + `editing_prompt` input argument. + editing_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated edit pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled editing_pooled_prompt_embeds will be generated from `editing_prompt` + input argument. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + batch_size = self.batch_size + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + num_edit_tokens = 0 + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + + if negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but image inversion " + f" has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of the input images." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(negative_prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(negative_pooled_prompt_embeds) + + if enable_edit_guidance and editing_prompt_embeds is None: + editing_prompt_2 = editing_prompt + + editing_prompts = [editing_prompt, editing_prompt_2] + edit_prompt_embeds_list = [] + + for editing_prompt, tokenizer, text_encoder in zip(editing_prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + editing_prompt = self.maybe_convert_prompt(editing_prompt, tokenizer) + + max_length = negative_prompt_embeds.shape[1] + edit_concepts_input = tokenizer( + # [x for item in editing_prompt for x in repeat(item, batch_size)], + editing_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + return_length=True, + ) + num_edit_tokens = edit_concepts_input.length - 2 + + edit_concepts_embeds = text_encoder( + edit_concepts_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + editing_pooled_prompt_embeds = edit_concepts_embeds[0] + if clip_skip is None: + edit_concepts_embeds = edit_concepts_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)] + + edit_prompt_embeds_list.append(edit_concepts_embeds) + + edit_concepts_embeds = torch.concat(edit_prompt_embeds_list, dim=-1) + elif not enable_edit_guidance: + edit_concepts_embeds = None + editing_pooled_prompt_embeds = None + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + bs_embed, seq_len, _ = negative_prompt_embeds.shape + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if enable_edit_guidance: + bs_embed_edit, seq_len, _ = edit_concepts_embeds.shape + edit_concepts_embeds = edit_concepts_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + edit_concepts_embeds = edit_concepts_embeds.repeat(1, num_images_per_prompt, 1) + edit_concepts_embeds = edit_concepts_embeds.view(bs_embed_edit * num_images_per_prompt, seq_len, -1) + + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if enable_edit_guidance: + editing_pooled_prompt_embeds = editing_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed_edit * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return ( + negative_prompt_embeds, + edit_concepts_embeds, + negative_pooled_prompt_embeds, + editing_pooled_prompt_embeds, + num_edit_tokens, + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, eta, generator=None): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, device, latents): + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet + def prepare_unet(self, attention_store, PnP: bool = False): + attn_procs = {} + for name in self.unet.attn_processors.keys(): + if name.startswith("mid_block"): + place_in_unet = "mid" + elif name.startswith("up_blocks"): + place_in_unet = "up" + elif name.startswith("down_blocks"): + place_in_unet = "down" + else: + continue + + if "attn2" in name and place_in_unet != "mid": + attn_procs[name] = LEDITSCrossAttnProcessor( + attention_store=attention_store, + place_in_unet=place_in_unet, + pnp=PnP, + editing_prompts=self.enabled_editing_prompts, + ) + else: + attn_procs[name] = AttnProcessor() + + self.unet.set_attn_processor(attn_procs) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + denoising_end: Optional[float] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + editing_prompt: Optional[Union[str, List[str]]] = None, + editing_prompt_embeddings: Optional[torch.Tensor] = None, + editing_pooled_prompt_embeds: Optional[torch.Tensor] = None, + reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, + edit_guidance_scale: Optional[Union[float, List[float]]] = 5, + edit_warmup_steps: Optional[Union[int, List[int]]] = 0, + edit_cooldown_steps: Optional[Union[int, List[int]]] = None, + edit_threshold: Optional[Union[float, List[float]]] = 0.9, + sem_guidance: Optional[List[torch.Tensor]] = None, + use_cross_attn_mask: bool = False, + use_intersect_mask: bool = False, + user_mask: Optional[torch.Tensor] = None, + attn_store_steps: Optional[List[int]] = [], + store_averaged_over_steps: bool = True, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for editing. The + [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusionXL.invert`] method has to be called beforehand. Edits + will always be performed for the last inverted image(s). + + Args: + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + editing_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. The image is reconstructed by setting + `editing_prompt = None`. Guidance direction of prompt should be specified via + `reverse_editing_direction`. + editing_prompt_embeddings (`torch.Tensor`, *optional*): + Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input argument. + editing_pooled_prompt_embeddings (`torch.Tensor`, *optional*): + Pre-generated pooled edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input + argument. + reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): + Whether the corresponding prompt in `editing_prompt` should be increased or decreased. + edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): + Guidance scale for guiding the image generation. If provided as list values should correspond to + `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 12 of [LEDITS++ + Paper](https://arxiv.org/abs/2301.12247). + edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): + Number of diffusion steps (for each prompt) for which guidance is not applied. + edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): + Number of diffusion steps (for each prompt) after which guidance is no longer applied. + edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): + Masking threshold of guidance. Threshold should be proportional to the image region that is modified. + 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++ + Paper](https://arxiv.org/abs/2301.12247). + sem_guidance (`List[torch.Tensor]`, *optional*): + List of pre-generated guidance vectors to be applied at generation. Length of the list has to + correspond to `num_inference_steps`. + use_cross_attn_mask: + Whether cross-attention masks are used. Cross-attention masks are always used when use_intersect_mask + is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of [LEDITS++ + paper](https://arxiv.org/pdf/2311.16711.pdf). + use_intersect_mask: + Whether the masking term is calculated as intersection of cross-attention masks and masks derived from + the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise estimate + are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf). + user_mask: + User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s + implicit masks do not meet user preferences. + attn_store_steps: + Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes. + store_averaged_over_steps: + Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps. If + False, attention maps for each step are stores separately. Just for visualization purposes. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images. + """ + if self.inversion_steps is None: + raise ValueError( + "You need to invert an input image first before calling the pipeline. The `invert` method has to be called beforehand. Edits will always be performed for the last inverted image(s)." + ) + + eta = self.eta + num_images_per_prompt = 1 + latents = self.init_latents + + zs = self.zs + self.scheduler.set_timesteps(len(self.scheduler.timesteps)) + + if use_intersect_mask: + use_cross_attn_mask = True + + if use_cross_attn_mask: + self.smoothing = LeditsGaussianSmoothing(self.device) + + if user_mask is not None: + user_mask = user_mask.to(self.device) + + # TODO: Check inputs + # 1. Check inputs. Raise error if not correct + # self.check_inputs( + # callback_steps, + # negative_prompt, + # negative_prompt_2, + # prompt_embeds, + # negative_prompt_embeds, + # pooled_prompt_embeds, + # negative_pooled_prompt_embeds, + # ) + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + + # 2. Define call parameters + batch_size = self.batch_size + + device = self._execution_device + + if editing_prompt: + enable_edit_guidance = True + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + self.enabled_editing_prompts = len(editing_prompt) + elif editing_prompt_embeddings is not None: + enable_edit_guidance = True + self.enabled_editing_prompts = editing_prompt_embeddings.shape[0] + else: + self.enabled_editing_prompts = 0 + enable_edit_guidance = False + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + edit_prompt_embeds, + negative_pooled_prompt_embeds, + pooled_edit_embeds, + num_edit_tokens, + ) = self.encode_prompt( + device=device, + num_images_per_prompt=num_images_per_prompt, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_embeds=negative_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + enable_edit_guidance=enable_edit_guidance, + editing_prompt=editing_prompt, + editing_prompt_embeds=editing_prompt_embeddings, + editing_pooled_prompt_embeds=editing_pooled_prompt_embeds, + ) + + # 4. Prepare timesteps + # self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.inversion_steps + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + + if use_cross_attn_mask: + self.attention_store = LeditsAttentionStore( + average=store_averaged_over_steps, + batch_size=batch_size, + max_size=(latents.shape[-2] / 4.0) * (latents.shape[-1] / 4.0), + max_resolution=None, + ) + self.prepare_unet(self.attention_store) + resolution = latents.shape[-2:] + att_res = (int(resolution[0] / 4), int(resolution[1] / 4)) + + # 5. Prepare latent variables + latents = self.prepare_latents(device=device, latents=latents) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(eta) + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(negative_pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + # 7. Prepare added time ids & embeddings + add_text_embeds = negative_pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + self.size, + crops_coords_top_left, + self.size, + dtype=negative_pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if enable_edit_guidance: + prompt_embeds = torch.cat([prompt_embeds, edit_prompt_embeds], dim=0) + add_text_embeds = torch.cat([add_text_embeds, pooled_edit_embeds], dim=0) + edit_concepts_time_ids = add_time_ids.repeat(edit_prompt_embeds.shape[0], 1) + add_time_ids = torch.cat([add_time_ids, edit_concepts_time_ids], dim=0) + self.text_cross_attention_maps = [editing_prompt] if isinstance(editing_prompt, str) else editing_prompt + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None: + # TODO: fix image encoding + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = image_embeds.to(device) + + # 8. Denoising loop + self.sem_guidance = None + self.activation_mask = None + + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts) # [b,4, 64, 64] + noise_pred_uncond = noise_pred_out[0] + noise_pred_edit_concepts = noise_pred_out[1:] + + noise_guidance_edit = torch.zeros( + noise_pred_uncond.shape, + device=self.device, + dtype=noise_pred_uncond.dtype, + ) + + if sem_guidance is not None and len(sem_guidance) > i: + noise_guidance_edit += sem_guidance[i].to(self.device) + + elif enable_edit_guidance: + if self.activation_mask is None: + self.activation_mask = torch.zeros( + (len(timesteps), self.enabled_editing_prompts, *noise_pred_edit_concepts[0].shape) + ) + if self.sem_guidance is None: + self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape)) + + # noise_guidance_edit = torch.zeros_like(noise_guidance) + for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): + if isinstance(edit_warmup_steps, list): + edit_warmup_steps_c = edit_warmup_steps[c] + else: + edit_warmup_steps_c = edit_warmup_steps + if i < edit_warmup_steps_c: + continue + + if isinstance(edit_guidance_scale, list): + edit_guidance_scale_c = edit_guidance_scale[c] + else: + edit_guidance_scale_c = edit_guidance_scale + + if isinstance(edit_threshold, list): + edit_threshold_c = edit_threshold[c] + else: + edit_threshold_c = edit_threshold + if isinstance(reverse_editing_direction, list): + reverse_editing_direction_c = reverse_editing_direction[c] + else: + reverse_editing_direction_c = reverse_editing_direction + + if isinstance(edit_cooldown_steps, list): + edit_cooldown_steps_c = edit_cooldown_steps[c] + elif edit_cooldown_steps is None: + edit_cooldown_steps_c = i + 1 + else: + edit_cooldown_steps_c = edit_cooldown_steps + + if i >= edit_cooldown_steps_c: + continue + + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond + + if reverse_editing_direction_c: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c + + if user_mask is not None: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask + + if use_cross_attn_mask: + out = self.attention_store.aggregate_attention( + attention_maps=self.attention_store.step_store, + prompts=self.text_cross_attention_maps, + res=att_res, + from_where=["up", "down"], + is_cross=True, + select=self.text_cross_attention_maps.index(editing_prompt[c]), + ) + attn_map = out[:, :, :, 1 : 1 + num_edit_tokens[c]] # 0 -> startoftext + + # average over all tokens + if attn_map.shape[3] != num_edit_tokens[c]: + raise ValueError( + f"Incorrect shape of attention_map. Expected size {num_edit_tokens[c]}, but found {attn_map.shape[3]}!" + ) + attn_map = torch.sum(attn_map, dim=3) + + # gaussian_smoothing + attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect") + attn_map = self.smoothing(attn_map).squeeze(1) + + # torch.quantile function expects float32 + if attn_map.dtype == torch.float32: + tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1) + else: + tmp = torch.quantile( + attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1 + ).to(attn_map.dtype) + attn_mask = torch.where( + attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1, *att_res), 1.0, 0.0 + ) + + # resolution must match latent space dimension + attn_mask = F.interpolate( + attn_mask.unsqueeze(1), + noise_guidance_edit_tmp.shape[-2:], # 64,64 + ).repeat(1, 4, 1, 1) + self.activation_mask[i, c] = attn_mask.detach().cpu() + if not use_intersect_mask: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask + + if use_intersect_mask: + noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) + noise_guidance_edit_tmp_quantile = torch.sum( + noise_guidance_edit_tmp_quantile, dim=1, keepdim=True + ) + noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat( + 1, self.unet.config.in_channels, 1, 1 + ) + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp_quantile.dtype == torch.float32: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp_quantile.dtype) + + intersect_mask = ( + torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + torch.ones_like(noise_guidance_edit_tmp), + torch.zeros_like(noise_guidance_edit_tmp), + ) + * attn_mask + ) + + self.activation_mask[i, c] = intersect_mask.detach().cpu() + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask + + elif not use_cross_attn_mask: + # calculate quantile + noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) + noise_guidance_edit_tmp_quantile = torch.sum( + noise_guidance_edit_tmp_quantile, dim=1, keepdim=True + ) + noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1) + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp_quantile.dtype == torch.float32: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp_quantile.dtype) + + self.activation_mask[i, c] = ( + torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + torch.ones_like(noise_guidance_edit_tmp), + torch.zeros_like(noise_guidance_edit_tmp), + ) + .detach() + .cpu() + ) + + noise_guidance_edit_tmp = torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + noise_guidance_edit_tmp, + torch.zeros_like(noise_guidance_edit_tmp), + ) + + noise_guidance_edit += noise_guidance_edit_tmp + + self.sem_guidance[i] = noise_guidance_edit.detach().cpu() + + noise_pred = noise_pred_uncond + noise_guidance_edit + + # compute the previous noisy sample x_t -> x_t-1 + if enable_edit_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_edit_concepts.mean(dim=0, keepdim=False), + guidance_rescale=self.guidance_rescale, + ) + + idx = t_to_idx[int(t)] + latents = self.scheduler.step( + noise_pred, t, latents, variance_noise=zs[idx], **extra_step_kwargs, return_dict=False + )[0] + + # step callback + if use_cross_attn_mask: + store_step = i in attn_store_steps + self.attention_store.between_steps(store_step) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + # negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return LEditsPPDiffusionPipelineOutput(images=image, nsfw_content_detected=None) + + @torch.no_grad() + # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image + def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None): + image = self.image_processor.preprocess( + image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + resized = self.image_processor.postprocess(image=image, output_type="pil") + + if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5: + logger.warning( + "Your input images far exceed the default resolution of the underlying diffusion model. " + "The output images may contain severe artifacts! " + "Consider down-sampling the input using the `height` and `width` parameters" + ) + image = image.to(self.device, dtype=dtype) + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + image = image.float() + self.upcast_vae() + + x0 = self.vae.encode(image).latent_dist.mode() + x0 = x0.to(dtype) + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + x0 = self.vae.config.scaling_factor * x0 + return x0, resized + + @torch.no_grad() + def invert( + self, + image: PipelineImageInput, + source_prompt: str = "", + source_guidance_scale=3.5, + negative_prompt: str = None, + negative_prompt_2: str = None, + num_inversion_steps: int = 50, + skip: float = 0.15, + generator: Optional[torch.Generator] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + num_zero_noise_steps: int = 3, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + The function to the pipeline for image inversion as described by the [LEDITS++ + Paper](https://arxiv.org/abs/2301.12247). If the scheduler is set to [`~schedulers.DDIMScheduler`] the + inversion proposed by [edit-friendly DPDM](https://arxiv.org/abs/2304.06140) will be performed instead. + + Args: + image (`PipelineImageInput`): + Input for the image(s) that are to be edited. Multiple input images have to default to the same aspect + ratio. + source_prompt (`str`, defaults to `""`): + Prompt describing the input image that will be used for guidance during inversion. Guidance is disabled + if the `source_prompt` is `""`. + source_guidance_scale (`float`, defaults to `3.5`): + Strength of guidance during inversion. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_inversion_steps (`int`, defaults to `50`): + Number of total performed inversion steps after discarding the initial `skip` steps. + skip (`float`, defaults to `0.15`): + Portion of initial steps that will be ignored for inversion and subsequent generation. Lower values + will lead to stronger changes to the input image. `skip` has to be between `0` and `1`. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make inversion + deterministic. + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + num_zero_noise_steps (`int`, defaults to `3`): + Number of final diffusion steps that will not renoise the current image. If no steps are set to zero + SD-XL in combination with [`DPMSolverMultistepScheduler`] will produce noise artifacts. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Returns: + [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s) + and respective VAE reconstruction(s). + """ + + # Reset attn processor, we do not want to store attn maps during inversion + self.unet.set_attn_processor(AttnProcessor()) + + self.eta = 1.0 + + self.scheduler.config.timestep_spacing = "leading" + self.scheduler.set_timesteps(int(num_inversion_steps * (1 + skip))) + self.inversion_steps = self.scheduler.timesteps[-num_inversion_steps:] + timesteps = self.inversion_steps + + num_images_per_prompt = 1 + + device = self._execution_device + + # 0. Ensure that only uncond embedding is used if prompt = "" + if source_prompt == "": + # noise pred should only be noise_pred_uncond + source_guidance_scale = 0.0 + do_classifier_free_guidance = False + else: + do_classifier_free_guidance = source_guidance_scale > 1.0 + + # 1. prepare image + x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype) + width = x0.shape[2] * self.vae_scale_factor + height = x0.shape[3] * self.vae_scale_factor + self.size = (height, width) + + self.batch_size = x0.shape[0] + + # 2. get embeddings + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + + if isinstance(source_prompt, str): + source_prompt = [source_prompt] * self.batch_size + + ( + negative_prompt_embeds, + prompt_embeds, + negative_pooled_prompt_embeds, + edit_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + device=device, + num_images_per_prompt=num_images_per_prompt, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + editing_prompt=source_prompt, + lora_scale=text_encoder_lora_scale, + enable_edit_guidance=do_classifier_free_guidance, + ) + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(negative_pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + # 3. Prepare added time ids & embeddings + add_text_embeds = negative_pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + self.size, + crops_coords_top_left, + self.size, + dtype=negative_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([add_text_embeds, edit_pooled_prompt_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + negative_prompt_embeds = negative_prompt_embeds.to(device) + + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(self.batch_size * num_images_per_prompt, 1) + + # autoencoder reconstruction + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + x0_tmp = x0.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image_rec = self.vae.decode( + x0_tmp / self.vae.config.scaling_factor, return_dict=False, generator=generator + )[0] + elif self.vae.config.force_upcast: + x0_tmp = x0.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image_rec = self.vae.decode( + x0_tmp / self.vae.config.scaling_factor, return_dict=False, generator=generator + )[0] + else: + image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0] + + image_rec = self.image_processor.postprocess(image_rec, output_type="pil") + + # 5. find zs and xts + variance_noise_shape = (num_inversion_steps, *x0.shape) + + # intermediate latents + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype) + + for t in reversed(timesteps): + idx = num_inversion_steps - t_to_idx[int(t)] - 1 + noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype) + xts[idx] = self.scheduler.add_noise(x0, noise, t.unsqueeze(0)) + xts = torch.cat([x0.unsqueeze(0), xts], dim=0) + + # noise maps + zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype) + + self.scheduler.set_timesteps(len(self.scheduler.timesteps)) + + for t in self.progress_bar(timesteps): + idx = num_inversion_steps - t_to_idx[int(t)] - 1 + # 1. predict noise residual + xt = xts[idx + 1] + + latent_model_input = torch.cat([xt] * 2) if do_classifier_free_guidance else xt + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=negative_prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # 2. perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk(2) + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + noise_pred = noise_pred_uncond + source_guidance_scale * (noise_pred_text - noise_pred_uncond) + + xtm1 = xts[idx] + z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, self.eta) + zs[idx] = z + + # correction to avoid error accumulation + xts[idx] = xtm1_corrected + + self.init_latents = xts[-1] + zs = zs.flip(0) + + if num_zero_noise_steps > 0: + zs[-num_zero_noise_steps:] = torch.zeros_like(zs[-num_zero_noise_steps:]) + self.zs = zs + return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec) + + +# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_ddim +def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if scheduler.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred + + # modifed so that updated xtm1 is returned as well (to avoid error accumulation) + mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + if variance > 0.0: + noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta) + else: + noise = torch.tensor([0.0]).to(latents.device) + + return noise, mu_xt + (eta * variance**0.5) * noise + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_sde_dpm_pp_2nd +def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta): + def first_order_update(model_output, sample): # timestep, prev_timestep, sample): + sigma_t, sigma_s = scheduler.sigmas[scheduler.step_index + 1], scheduler.sigmas[scheduler.step_index] + alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = scheduler._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + + mu_xt = (sigma_t / sigma_s * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + mu_xt = scheduler.dpm_solver_first_order_update( + model_output=model_output, sample=sample, noise=torch.zeros_like(sample) + ) + + sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) + if sigma > 0.0: + noise = (prev_latents - mu_xt) / sigma + else: + noise = torch.tensor([0.0]).to(sample.device) + + prev_sample = mu_xt + sigma * noise + return noise, prev_sample + + def second_order_update(model_output_list, sample): # timestep_list, prev_timestep, sample): + sigma_t, sigma_s0, sigma_s1 = ( + scheduler.sigmas[scheduler.step_index + 1], + scheduler.sigmas[scheduler.step_index], + scheduler.sigmas[scheduler.step_index - 1], + ) + + alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = scheduler._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = scheduler._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + + mu_xt = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + ) + + sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) + if sigma > 0.0: + noise = (prev_latents - mu_xt) / sigma + else: + noise = torch.tensor([0.0]).to(sample.device) + + prev_sample = mu_xt + sigma * noise + + return noise, prev_sample + + if scheduler.step_index is None: + scheduler._init_step_index(timestep) + + model_output = scheduler.convert_model_output(model_output=noise_pred, sample=latents) + for i in range(scheduler.config.solver_order - 1): + scheduler.model_outputs[i] = scheduler.model_outputs[i + 1] + scheduler.model_outputs[-1] = model_output + + if scheduler.lower_order_nums < 1: + noise, prev_sample = first_order_update(model_output, latents) + else: + noise, prev_sample = second_order_update(scheduler.model_outputs, latents) + + if scheduler.lower_order_nums < scheduler.config.solver_order: + scheduler.lower_order_nums += 1 + + # upon completion increase step index by one + scheduler._step_index += 1 + + return noise, prev_sample + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise +def compute_noise(scheduler, *args): + if isinstance(scheduler, DDIMScheduler): + return compute_noise_ddim(scheduler, *args) + elif ( + isinstance(scheduler, DPMSolverMultistepScheduler) + and scheduler.config.algorithm_type == "sde-dpmsolver++" + and scheduler.config.solver_order == 2 + ): + return compute_noise_sde_dpm_pp_2nd(scheduler, *args) + else: + raise NotImplementedError diff --git a/diffusers/src/diffusers/pipelines/ledits_pp/pipeline_output.py b/diffusers/src/diffusers/pipelines/ledits_pp/pipeline_output.py new file mode 100755 index 0000000..756be82 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/ledits_pp/pipeline_output.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class LEditsPPDiffusionPipelineOutput(BaseOutput): + """ + Output class for LEdits++ Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`List[bool]`) + List indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or + `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +@dataclass +class LEditsPPInversionPipelineOutput(BaseOutput): + """ + Output class for LEdits++ Diffusion pipelines. + + Args: + input_images (`List[PIL.Image.Image]` or `np.ndarray`) + List of the cropped and resized input images as PIL images of length `batch_size` or NumPy array of shape ` + (batch_size, height, width, num_channels)`. + vae_reconstruction_images (`List[PIL.Image.Image]` or `np.ndarray`) + List of VAE reconstruction of all input images as PIL images of length `batch_size` or NumPy array of shape + ` (batch_size, height, width, num_channels)`. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + vae_reconstruction_images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/diffusers/src/diffusers/pipelines/lumina/__init__.py b/diffusers/src/diffusers/pipelines/lumina/__init__.py new file mode 100755 index 0000000..ca13963 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/lumina/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_lumina"] = ["LuminaText2ImgPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_lumina import LuminaText2ImgPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/lumina/pipeline_lumina.py b/diffusers/src/diffusers/pipelines/lumina/pipeline_lumina.py new file mode 100755 index 0000000..6bc7e8d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -0,0 +1,897 @@ +# Copyright 2024 Alpha-VLLM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import math +import re +import urllib.parse as ul +from typing import List, Optional, Tuple, Union + +import torch +from transformers import AutoModel, AutoTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...models.embeddings import get_2d_rotary_pos_embed_lumina +from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + BACKENDS_MAPPING, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LuminaText2ImgPipeline + + >>> pipe = LuminaText2ImgPipeline.from_pretrained( + ... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LuminaText2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Lumina-T2I. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`AutoModel`]): + Frozen text-encoder. Lumina-T2I uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant. + tokenizer (`AutoModel`): + Tokenizer of class + [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel). + transformer ([`Transformer2DModel`]): + A text conditioned `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + transformer: LuminaNextDiT2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: AutoModel, + tokenizer: AutoTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.max_sequence_length = 256 + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.default_image_size = self.default_sample_size * self.vae_scale_factor + + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + clean_caption: Optional[bool] = False, + max_length: Optional[int] = None, + ): + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + pad_to_multiple_of=8, + max_length=self.max_sequence_length, + truncation=True, + padding=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because Gemma can only handle sequences up to" + f" {self.max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + prompt_embeds = prompt_embeds.hidden_states[-2] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, prompt_attention_mask + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + """ + if device is None: + device = self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + clean_caption=clean_caption, + ) + + # Get negative embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else "" + + # Normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + # Padding negative prompt to the same length with prompt + prompt_max_length = prompt_embeds.shape[1] + negative_text_inputs = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=prompt_max_length, + truncation=True, + return_tensors="pt", + ) + negative_text_input_ids = negative_text_inputs.input_ids.to(device) + negative_prompt_attention_mask = negative_text_inputs.attention_mask.to(device) + # Get the negative prompt embeddings + negative_prompt_embeds = self.text_encoder( + negative_text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True, + ) + + negative_dtype = self.text_encoder.dtype + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + _, seq_len, _ = negative_prompt_embeds.shape + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=negative_dtype, device=device) + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + batch_size * num_images_per_prompt, -1 + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + width: Optional[int] = None, + height: Optional[int] = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Union[str, List[str]] = None, + sigmas: List[float] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + clean_caption: bool = True, + max_sequence_length: int = 256, + scaling_watershed: Optional[float] = 1.0, + proportional_attn: Optional[bool] = True, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 30): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + max_sequence_length (`int` defaults to 120): + Maximum sequence length to use with the `prompt`. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + cross_attention_kwargs = {} + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if proportional_attn: + cross_attention_kwargs["base_sequence_length"] = (self.default_image_size // 16) ** 2 + + scaling_factor = math.sqrt(width * height / self.default_image_size**2) + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor( + [current_timestep], + dtype=dtype, + device=latent_model_input.device, + ) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image + current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps + + # prepare image_rotary_emb for positional encoding + # dynamic scaling_factor for different resolution. + # NOTE: For `Time-aware` denosing mechanism from Lumina-Next + # https://arxiv.org/abs/2406.18583, Sec 2.3 + # NOTE: We should compute different image_rotary_emb with different timestep. + if current_timestep[0] < scaling_watershed: + linear_factor = scaling_factor + ntk_factor = 1.0 + else: + linear_factor = 1.0 + ntk_factor = scaling_factor + image_rotary_emb = get_2d_rotary_pos_embed_lumina( + self.transformer.head_dim, + 384, + 384, + linear_factor=linear_factor, + ntk_factor=ntk_factor, + ) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=current_timestep, + encoder_hidden_states=prompt_embeds, + encoder_mask=prompt_attention_mask, + image_rotary_emb=image_rotary_emb, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # perform guidance scale + # NOTE: For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] + if do_classifier_free_guidance: + noise_pred_eps, noise_pred_rest = noise_pred[:, :3], noise_pred[:, 3:] + noise_pred_cond_eps, noise_pred_uncond_eps = torch.split( + noise_pred_eps, len(noise_pred_eps) // 2, dim=0 + ) + noise_pred_half = noise_pred_uncond_eps + guidance_scale * ( + noise_pred_cond_eps - noise_pred_uncond_eps + ) + noise_pred_eps = torch.cat([noise_pred_half, noise_pred_half], dim=0) + + noise_pred = torch.cat([noise_pred_eps, noise_pred_rest], dim=1) + noise_pred, _ = noise_pred.chunk(2, dim=0) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + noise_pred = -noise_pred + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + progress_bar.update() + + if not output_type == "latent": + latents = latents / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/marigold/__init__.py b/diffusers/src/diffusers/pipelines/marigold/__init__.py new file mode 100755 index 0000000..b5ae03a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/marigold/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["marigold_image_processing"] = ["MarigoldImageProcessor"] + _import_structure["pipeline_marigold_depth"] = ["MarigoldDepthOutput", "MarigoldDepthPipeline"] + _import_structure["pipeline_marigold_normals"] = ["MarigoldNormalsOutput", "MarigoldNormalsPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .marigold_image_processing import MarigoldImageProcessor + from .pipeline_marigold_depth import MarigoldDepthOutput, MarigoldDepthPipeline + from .pipeline_marigold_normals import MarigoldNormalsOutput, MarigoldNormalsPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/marigold/marigold_image_processing.py b/diffusers/src/diffusers/pipelines/marigold/marigold_image_processing.py new file mode 100755 index 0000000..51b9983 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/marigold/marigold_image_processing.py @@ -0,0 +1,576 @@ +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from PIL import Image + +from ... import ConfigMixin +from ...configuration_utils import register_to_config +from ...image_processor import PipelineImageInput +from ...utils import CONFIG_NAME, logging +from ...utils.import_utils import is_matplotlib_available + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MarigoldImageProcessor(ConfigMixin): + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + vae_scale_factor: int = 8, + do_normalize: bool = True, + do_range_check: bool = True, + ): + super().__init__() + + @staticmethod + def expand_tensor_or_array(images: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]: + """ + Expand a tensor or array to a specified number of images. + """ + if isinstance(images, np.ndarray): + if images.ndim == 2: # [H,W] -> [1,H,W,1] + images = images[None, ..., None] + if images.ndim == 3: # [H,W,C] -> [1,H,W,C] + images = images[None] + elif isinstance(images, torch.Tensor): + if images.ndim == 2: # [H,W] -> [1,1,H,W] + images = images[None, None] + elif images.ndim == 3: # [1,H,W] -> [1,1,H,W] + images = images[None] + else: + raise ValueError(f"Unexpected input type: {type(images)}") + return images + + @staticmethod + def pt_to_numpy(images: torch.Tensor) -> np.ndarray: + """ + Convert a PyTorch tensor to a NumPy image. + """ + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + return images + + @staticmethod + def numpy_to_pt(images: np.ndarray) -> torch.Tensor: + """ + Convert a NumPy image to a PyTorch tensor. + """ + if np.issubdtype(images.dtype, np.integer) and not np.issubdtype(images.dtype, np.unsignedinteger): + raise ValueError(f"Input image dtype={images.dtype} cannot be a signed integer.") + if np.issubdtype(images.dtype, np.complexfloating): + raise ValueError(f"Input image dtype={images.dtype} cannot be complex.") + if np.issubdtype(images.dtype, bool): + raise ValueError(f"Input image dtype={images.dtype} cannot be boolean.") + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + @staticmethod + def resize_antialias( + image: torch.Tensor, size: Tuple[int, int], mode: str, is_aa: Optional[bool] = None + ) -> torch.Tensor: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + antialias = is_aa and mode in ("bilinear", "bicubic") + image = F.interpolate(image, size, mode=mode, antialias=antialias) + + return image + + @staticmethod + def resize_to_max_edge(image: torch.Tensor, max_edge_sz: int, mode: str) -> torch.Tensor: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + h, w = image.shape[-2:] + max_orig = max(h, w) + new_h = h * max_edge_sz // max_orig + new_w = w * max_edge_sz // max_orig + + if new_h == 0 or new_w == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{w} x {h}]") + + image = MarigoldImageProcessor.resize_antialias(image, (new_h, new_w), mode, is_aa=True) + + return image + + @staticmethod + def pad_image(image: torch.Tensor, align: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + h, w = image.shape[-2:] + ph, pw = -h % align, -w % align + + image = F.pad(image, (0, pw, 0, ph), mode="replicate") + + return image, (ph, pw) + + @staticmethod + def unpad_image(image: torch.Tensor, padding: Tuple[int, int]) -> torch.Tensor: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + ph, pw = padding + uh = None if ph == 0 else -ph + uw = None if pw == 0 else -pw + + image = image[:, :, :uh, :uw] + + return image + + @staticmethod + def load_image_canonical( + image: Union[torch.Tensor, np.ndarray, Image.Image], + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ) -> Tuple[torch.Tensor, int]: + if isinstance(image, Image.Image): + image = np.array(image) + + image_dtype_max = None + if isinstance(image, (np.ndarray, torch.Tensor)): + image = MarigoldImageProcessor.expand_tensor_or_array(image) + if image.ndim != 4: + raise ValueError("Input image is not 2-, 3-, or 4-dimensional.") + if isinstance(image, np.ndarray): + if np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.unsignedinteger): + raise ValueError(f"Input image dtype={image.dtype} cannot be a signed integer.") + if np.issubdtype(image.dtype, np.complexfloating): + raise ValueError(f"Input image dtype={image.dtype} cannot be complex.") + if np.issubdtype(image.dtype, bool): + raise ValueError(f"Input image dtype={image.dtype} cannot be boolean.") + if np.issubdtype(image.dtype, np.unsignedinteger): + image_dtype_max = np.iinfo(image.dtype).max + image = image.astype(np.float32) # because torch does not have unsigned dtypes beyond torch.uint8 + image = MarigoldImageProcessor.numpy_to_pt(image) + + if torch.is_tensor(image) and not torch.is_floating_point(image) and image_dtype_max is None: + if image.dtype != torch.uint8: + raise ValueError(f"Image dtype={image.dtype} is not supported.") + image_dtype_max = 255 + + if not torch.is_tensor(image): + raise ValueError(f"Input type unsupported: {type(image)}.") + + if image.shape[1] == 1: + image = image.repeat(1, 3, 1, 1) # [N,1,H,W] -> [N,3,H,W] + if image.shape[1] != 3: + raise ValueError(f"Input image is not 1- or 3-channel: {image.shape}.") + + image = image.to(device=device, dtype=dtype) + + if image_dtype_max is not None: + image = image / image_dtype_max + + return image + + @staticmethod + def check_image_values_range(image: torch.Tensor) -> None: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.min().item() < 0.0 or image.max().item() > 1.0: + raise ValueError("Input image data is partially outside of the [0,1] range.") + + def preprocess( + self, + image: PipelineImageInput, + processing_resolution: Optional[int] = None, + resample_method_input: str = "bilinear", + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ): + if isinstance(image, list): + images = None + for i, img in enumerate(image): + img = self.load_image_canonical(img, device, dtype) # [N,3,H,W] + if images is None: + images = img + else: + if images.shape[2:] != img.shape[2:]: + raise ValueError( + f"Input image[{i}] has incompatible dimensions {img.shape[2:]} with the previous images " + f"{images.shape[2:]}" + ) + images = torch.cat((images, img), dim=0) + image = images + del images + else: + image = self.load_image_canonical(image, device, dtype) # [N,3,H,W] + + original_resolution = image.shape[2:] + + if self.config.do_range_check: + self.check_image_values_range(image) + + if self.config.do_normalize: + image = image * 2.0 - 1.0 + + if processing_resolution is not None and processing_resolution > 0: + image = self.resize_to_max_edge(image, processing_resolution, resample_method_input) # [N,3,PH,PW] + + image, padding = self.pad_image(image, self.config.vae_scale_factor) # [N,3,PPH,PPW] + + return image, padding, original_resolution + + @staticmethod + def colormap( + image: Union[np.ndarray, torch.Tensor], + cmap: str = "Spectral", + bytes: bool = False, + _force_method: Optional[str] = None, + ) -> Union[np.ndarray, torch.Tensor]: + """ + Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the + behavior of matplotlib.colormaps, but allows the user to use the most discriminative color maps ("Spectral", + "binary") without having to install or import matplotlib. For all other cases, the function will attempt to use + the native implementation. + + Args: + image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor. + cmap: Colormap name. + bytes: Whether to return the output as uint8 or floating point image. + _force_method: + Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom + implementation of the select color maps (`"custom"`), or rely on autodetection (`None`, default). + + Returns: + An RGB-colorized tensor corresponding to the input image. + """ + if not (torch.is_tensor(image) or isinstance(image, np.ndarray)): + raise ValueError("Argument must be a numpy array or torch tensor.") + if _force_method not in (None, "matplotlib", "custom"): + raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.") + + supported_cmaps = { + "binary": [ + (1.0, 1.0, 1.0), + (0.0, 0.0, 0.0), + ], + "Spectral": [ # Taken from matplotlib/_cm.py + (0.61960784313725492, 0.003921568627450980, 0.25882352941176473), # 0.0 -> [0] + (0.83529411764705885, 0.24313725490196078, 0.30980392156862746), + (0.95686274509803926, 0.42745098039215684, 0.2627450980392157), + (0.99215686274509807, 0.68235294117647061, 0.38039215686274508), + (0.99607843137254903, 0.8784313725490196, 0.54509803921568623), + (1.0, 1.0, 0.74901960784313726), + (0.90196078431372551, 0.96078431372549022, 0.59607843137254901), + (0.6705882352941176, 0.8666666666666667, 0.64313725490196083), + (0.4, 0.76078431372549016, 0.6470588235294118), + (0.19607843137254902, 0.53333333333333333, 0.74117647058823533), + (0.36862745098039218, 0.30980392156862746, 0.63529411764705879), # 1.0 -> [K-1] + ], + } + + def method_matplotlib(image, cmap, bytes=False): + if is_matplotlib_available(): + import matplotlib + else: + return None + + arg_is_pt, device = torch.is_tensor(image), None + if arg_is_pt: + image, device = image.cpu().numpy(), image.device + + if cmap not in matplotlib.colormaps: + raise ValueError( + f"Unexpected color map {cmap}; available options are: {', '.join(list(matplotlib.colormaps.keys()))}" + ) + + cmap = matplotlib.colormaps[cmap] + out = cmap(image, bytes=bytes) # [?,4] + out = out[..., :3] # [?,3] + + if arg_is_pt: + out = torch.tensor(out, device=device) + + return out + + def method_custom(image, cmap, bytes=False): + arg_is_np = isinstance(image, np.ndarray) + if arg_is_np: + image = torch.tensor(image) + if image.dtype == torch.uint8: + image = image.float() / 255 + else: + image = image.float() + + is_cmap_reversed = cmap.endswith("_r") + if is_cmap_reversed: + cmap = cmap[:-2] + + if cmap not in supported_cmaps: + raise ValueError( + f"Only {list(supported_cmaps.keys())} color maps are available without installing matplotlib." + ) + + cmap = supported_cmaps[cmap] + if is_cmap_reversed: + cmap = cmap[::-1] + cmap = torch.tensor(cmap, dtype=torch.float, device=image.device) # [K,3] + K = cmap.shape[0] + + pos = image.clamp(min=0, max=1) * (K - 1) + left = pos.long() + right = (left + 1).clamp(max=K - 1) + + d = (pos - left.float()).unsqueeze(-1) + left_colors = cmap[left] + right_colors = cmap[right] + + out = (1 - d) * left_colors + d * right_colors + + if bytes: + out = (out * 255).to(torch.uint8) + + if arg_is_np: + out = out.numpy() + + return out + + if _force_method is None and torch.is_tensor(image) and cmap == "Spectral": + return method_custom(image, cmap, bytes) + + out = None + if _force_method != "custom": + out = method_matplotlib(image, cmap, bytes) + + if _force_method == "matplotlib" and out is None: + raise ImportError("Make sure to install matplotlib if you want to use a color map other than 'Spectral'.") + + if out is None: + out = method_custom(image, cmap, bytes) + + return out + + @staticmethod + def visualize_depth( + depth: Union[ + PIL.Image.Image, + np.ndarray, + torch.Tensor, + List[PIL.Image.Image], + List[np.ndarray], + List[torch.Tensor], + ], + val_min: float = 0.0, + val_max: float = 1.0, + color_map: str = "Spectral", + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + """ + Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`. + + Args: + depth (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], + List[torch.Tensor]]`): Depth maps. + val_min (`float`, *optional*, defaults to `0.0`): Minimum value of the visualized depth range. + val_max (`float`, *optional*, defaults to `1.0`): Maximum value of the visualized depth range. + color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel + depth prediction into colored representation. + + Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with depth maps visualization. + """ + if val_max <= val_min: + raise ValueError(f"Invalid values range: [{val_min}, {val_max}].") + + def visualize_depth_one(img, idx=None): + prefix = "Depth" + (f"[{idx}]" if idx else "") + if isinstance(img, PIL.Image.Image): + if img.mode != "I;16": + raise ValueError(f"{prefix}: invalid PIL mode={img.mode}.") + img = np.array(img).astype(np.float32) / (2**16 - 1) + if isinstance(img, np.ndarray) or torch.is_tensor(img): + if img.ndim != 2: + raise ValueError(f"{prefix}: unexpected shape={img.shape}.") + if isinstance(img, np.ndarray): + img = torch.from_numpy(img) + if not torch.is_floating_point(img): + raise ValueError(f"{prefix}: unexected dtype={img.dtype}.") + else: + raise ValueError(f"{prefix}: unexpected type={type(img)}.") + if val_min != 0.0 or val_max != 1.0: + img = (img - val_min) / (val_max - val_min) + img = MarigoldImageProcessor.colormap(img, cmap=color_map, bytes=True) # [H,W,3] + img = PIL.Image.fromarray(img.cpu().numpy()) + return img + + if depth is None or isinstance(depth, list) and any(o is None for o in depth): + raise ValueError("Input depth is `None`") + if isinstance(depth, (np.ndarray, torch.Tensor)): + depth = MarigoldImageProcessor.expand_tensor_or_array(depth) + if isinstance(depth, np.ndarray): + depth = MarigoldImageProcessor.numpy_to_pt(depth) # [N,H,W,1] -> [N,1,H,W] + if not (depth.ndim == 4 and depth.shape[1] == 1): # [N,1,H,W] + raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].") + return [visualize_depth_one(img[0], idx) for idx, img in enumerate(depth)] + elif isinstance(depth, list): + return [visualize_depth_one(img, idx) for idx, img in enumerate(depth)] + else: + raise ValueError(f"Unexpected input type: {type(depth)}") + + @staticmethod + def export_depth_to_16bit_png( + depth: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]], + val_min: float = 0.0, + val_max: float = 1.0, + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + def export_depth_to_16bit_png_one(img, idx=None): + prefix = "Depth" + (f"[{idx}]" if idx else "") + if not isinstance(img, np.ndarray) and not torch.is_tensor(img): + raise ValueError(f"{prefix}: unexpected type={type(img)}.") + if img.ndim != 2: + raise ValueError(f"{prefix}: unexpected shape={img.shape}.") + if torch.is_tensor(img): + img = img.cpu().numpy() + if not np.issubdtype(img.dtype, np.floating): + raise ValueError(f"{prefix}: unexected dtype={img.dtype}.") + if val_min != 0.0 or val_max != 1.0: + img = (img - val_min) / (val_max - val_min) + img = (img * (2**16 - 1)).astype(np.uint16) + img = PIL.Image.fromarray(img, mode="I;16") + return img + + if depth is None or isinstance(depth, list) and any(o is None for o in depth): + raise ValueError("Input depth is `None`") + if isinstance(depth, (np.ndarray, torch.Tensor)): + depth = MarigoldImageProcessor.expand_tensor_or_array(depth) + if isinstance(depth, np.ndarray): + depth = MarigoldImageProcessor.numpy_to_pt(depth) # [N,H,W,1] -> [N,1,H,W] + if not (depth.ndim == 4 and depth.shape[1] == 1): + raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].") + return [export_depth_to_16bit_png_one(img[0], idx) for idx, img in enumerate(depth)] + elif isinstance(depth, list): + return [export_depth_to_16bit_png_one(img, idx) for idx, img in enumerate(depth)] + else: + raise ValueError(f"Unexpected input type: {type(depth)}") + + @staticmethod + def visualize_normals( + normals: Union[ + np.ndarray, + torch.Tensor, + List[np.ndarray], + List[torch.Tensor], + ], + flip_x: bool = False, + flip_y: bool = False, + flip_z: bool = False, + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + """ + Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`. + + Args: + normals (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`): + Surface normals. + flip_x (`bool`, *optional*, defaults to `False`): Flips the X axis of the normals frame of reference. + Default direction is right. + flip_y (`bool`, *optional*, defaults to `False`): Flips the Y axis of the normals frame of reference. + Default direction is top. + flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference. + Default direction is facing the observer. + + Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with surface normals visualization. + """ + flip_vec = None + if any((flip_x, flip_y, flip_z)): + flip_vec = torch.tensor( + [ + (-1) ** flip_x, + (-1) ** flip_y, + (-1) ** flip_z, + ], + dtype=torch.float32, + ) + + def visualize_normals_one(img, idx=None): + img = img.permute(1, 2, 0) + if flip_vec is not None: + img *= flip_vec.to(img.device) + img = (img + 1.0) * 0.5 + img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy() + img = PIL.Image.fromarray(img) + return img + + if normals is None or isinstance(normals, list) and any(o is None for o in normals): + raise ValueError("Input normals is `None`") + if isinstance(normals, (np.ndarray, torch.Tensor)): + normals = MarigoldImageProcessor.expand_tensor_or_array(normals) + if isinstance(normals, np.ndarray): + normals = MarigoldImageProcessor.numpy_to_pt(normals) # [N,3,H,W] + if not (normals.ndim == 4 and normals.shape[1] == 3): + raise ValueError(f"Unexpected input shape={normals.shape}, expecting [N,3,H,W].") + return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)] + elif isinstance(normals, list): + return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)] + else: + raise ValueError(f"Unexpected input type: {type(normals)}") + + @staticmethod + def visualize_uncertainty( + uncertainty: Union[ + np.ndarray, + torch.Tensor, + List[np.ndarray], + List[torch.Tensor], + ], + saturation_percentile=95, + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + """ + Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`. + + Args: + uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`): + Uncertainty maps. + saturation_percentile (`int`, *optional*, defaults to `95`): + Specifies the percentile uncertainty value visualized with maximum intensity. + + Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with uncertainty visualization. + """ + + def visualize_uncertainty_one(img, idx=None): + prefix = "Uncertainty" + (f"[{idx}]" if idx else "") + if img.min() < 0: + raise ValueError(f"{prefix}: unexected data range, min={img.min()}.") + img = img.squeeze(0).cpu().numpy() + saturation_value = np.percentile(img, saturation_percentile) + img = np.clip(img * 255 / saturation_value, 0, 255) + img = img.astype(np.uint8) + img = PIL.Image.fromarray(img) + return img + + if uncertainty is None or isinstance(uncertainty, list) and any(o is None for o in uncertainty): + raise ValueError("Input uncertainty is `None`") + if isinstance(uncertainty, (np.ndarray, torch.Tensor)): + uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty) + if isinstance(uncertainty, np.ndarray): + uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty) # [N,1,H,W] + if not (uncertainty.ndim == 4 and uncertainty.shape[1] == 1): + raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,1,H,W].") + return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)] + elif isinstance(uncertainty, list): + return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)] + else: + raise ValueError(f"Unexpected input type: {type(uncertainty)}") diff --git a/diffusers/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py b/diffusers/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py new file mode 100755 index 0000000..a602ba6 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py @@ -0,0 +1,813 @@ +# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- +from dataclasses import dataclass +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput +from ...models import ( + AutoencoderKL, + UNet2DConditionModel, +) +from ...schedulers import ( + DDIMScheduler, + LCMScheduler, +) +from ...utils import ( + BaseOutput, + logging, + replace_example_docstring, +) +from ...utils.import_utils import is_scipy_available +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .marigold_image_processing import MarigoldImageProcessor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ +Examples: +```py +>>> import diffusers +>>> import torch + +>>> pipe = diffusers.MarigoldDepthPipeline.from_pretrained( +... "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16 +... ).to("cuda") + +>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +>>> depth = pipe(image) + +>>> vis = pipe.image_processor.visualize_depth(depth.prediction) +>>> vis[0].save("einstein_depth.png") + +>>> depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction) +>>> depth_16bit[0].save("einstein_depth_16bit.png") +``` +""" + + +@dataclass +class MarigoldDepthOutput(BaseOutput): + """ + Output class for Marigold monocular depth prediction pipeline. + + Args: + prediction (`np.ndarray`, `torch.Tensor`): + Predicted depth maps with values in the range [0, 1]. The shape is always $numimages \times 1 \times height + \times width$, regardless of whether the images were passed as a 4D array or a list. + uncertainty (`None`, `np.ndarray`, `torch.Tensor`): + Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages + \times 1 \times height \times width$. + latent (`None`, `torch.Tensor`): + Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. + The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. + """ + + prediction: Union[np.ndarray, torch.Tensor] + uncertainty: Union[None, np.ndarray, torch.Tensor] + latent: Union[None, torch.Tensor] + + +class MarigoldDepthPipeline(DiffusionPipeline): + """ + Pipeline for monocular depth estimation using the Marigold method: https://marigoldmonodepth.github.io. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the depth latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent + representations. + scheduler (`DDIMScheduler` or `LCMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + prediction_type (`str`, *optional*): + Type of predictions made by the model. + scale_invariant (`bool`, *optional*): + A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in + the model config. When used together with the `shift_invariant=True` flag, the model is also called + "affine-invariant". NB: overriding this value is not supported. + shift_invariant (`bool`, *optional*): + A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in + the model config. When used together with the `scale_invariant=True` flag, the model is also called + "affine-invariant". NB: overriding this value is not supported. + default_denoising_steps (`int`, *optional*): + The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable + quality with the given model. This value must be set in the model config. When the pipeline is called + without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure + reasonable results with various model flavors compatible with the pipeline, such as those relying on very + short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). + default_processing_resolution (`int`, *optional*): + The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in + the model config. When the pipeline is called without explicitly setting `processing_resolution`, the + default value is used. This is required to ensure reasonable results with various model flavors trained + with varying optimal processing resolution values. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + supported_prediction_types = ("depth", "disparity") + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, LCMScheduler], + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + prediction_type: Optional[str] = None, + scale_invariant: Optional[bool] = True, + shift_invariant: Optional[bool] = True, + default_denoising_steps: Optional[int] = None, + default_processing_resolution: Optional[int] = None, + ): + super().__init__() + + if prediction_type not in self.supported_prediction_types: + logger.warning( + f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: " + f"{self.supported_prediction_types}." + ) + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + self.register_to_config( + prediction_type=prediction_type, + scale_invariant=scale_invariant, + shift_invariant=shift_invariant, + default_denoising_steps=default_denoising_steps, + default_processing_resolution=default_processing_resolution, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.scale_invariant = scale_invariant + self.shift_invariant = shift_invariant + self.default_denoising_steps = default_denoising_steps + self.default_processing_resolution = default_processing_resolution + + self.empty_text_embedding = None + + self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + image: PipelineImageInput, + num_inference_steps: int, + ensemble_size: int, + processing_resolution: int, + resample_method_input: str, + resample_method_output: str, + batch_size: int, + ensembling_kwargs: Optional[Dict[str, Any]], + latents: Optional[torch.Tensor], + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + output_type: str, + output_uncertainty: bool, + ) -> int: + if num_inference_steps is None: + raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") + if num_inference_steps < 1: + raise ValueError("`num_inference_steps` must be positive.") + if ensemble_size < 1: + raise ValueError("`ensemble_size` must be positive.") + if ensemble_size == 2: + logger.warning( + "`ensemble_size` == 2 results are similar to no ensembling (1); " + "consider increasing the value to at least 3." + ) + if ensemble_size > 1 and (self.scale_invariant or self.shift_invariant) and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use ensembling.") + if ensemble_size == 1 and output_uncertainty: + raise ValueError( + "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " + "greater than 1." + ) + if processing_resolution is None: + raise ValueError( + "`processing_resolution` is not specified and could not be resolved from the model config." + ) + if processing_resolution < 0: + raise ValueError( + "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " + "downsampled processing." + ) + if processing_resolution % self.vae_scale_factor != 0: + raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") + if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_input` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_output` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if batch_size < 1: + raise ValueError("`batch_size` must be positive.") + if output_type not in ["pt", "np"]: + raise ValueError("`output_type` must be one of `pt` or `np`.") + if latents is not None and generator is not None: + raise ValueError("`latents` and `generator` cannot be used together.") + if ensembling_kwargs is not None: + if not isinstance(ensembling_kwargs, dict): + raise ValueError("`ensembling_kwargs` must be a dictionary.") + if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("mean", "median"): + raise ValueError("`ensembling_kwargs['reduction']` can be either `'mean'` or `'median'`.") + + # image checks + num_images = 0 + W, H = None, None + if not isinstance(image, list): + image = [image] + for i, img in enumerate(image): + if isinstance(img, np.ndarray) or torch.is_tensor(img): + if img.ndim not in (2, 3, 4): + raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") + H_i, W_i = img.shape[-2:] + N_i = 1 + if img.ndim == 4: + N_i = img.shape[0] + elif isinstance(img, Image.Image): + W_i, H_i = img.size + N_i = 1 + else: + raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") + if W is None: + W, H = W_i, H_i + elif (W, H) != (W_i, H_i): + raise ValueError( + f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" + ) + num_images += N_i + + # latents checks + if latents is not None: + if not torch.is_tensor(latents): + raise ValueError("`latents` must be a torch.Tensor.") + if latents.dim() != 4: + raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") + + if processing_resolution > 0: + max_orig = max(H, W) + new_H = H * processing_resolution // max_orig + new_W = W * processing_resolution // max_orig + if new_H == 0 or new_W == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") + W, H = new_W, new_H + w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor + h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor + shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) + + if latents.shape != shape_expected: + raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") + + # generator checks + if generator is not None: + if isinstance(generator, list): + if len(generator) != num_images * ensemble_size: + raise ValueError( + "The number of generators must match the total number of ensemble members for all input images." + ) + if not all(g.device.type == generator[0].device.type for g in generator): + raise ValueError("`generator` device placement is not consistent in the list.") + elif not isinstance(generator, torch.Generator): + raise ValueError(f"Unsupported generator type: {type(generator)}.") + + return num_images + + def progress_bar(self, iterable=None, total=None, desc=None, leave=True): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + progress_bar_config = dict(**self._progress_bar_config) + progress_bar_config["desc"] = progress_bar_config.get("desc", desc) + progress_bar_config["leave"] = progress_bar_config.get("leave", leave) + if iterable is not None: + return tqdm(iterable, **progress_bar_config) + elif total is not None: + return tqdm(total=total, **progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + num_inference_steps: Optional[int] = None, + ensemble_size: int = 1, + processing_resolution: Optional[int] = None, + match_input_resolution: bool = True, + resample_method_input: str = "bilinear", + resample_method_output: str = "bilinear", + batch_size: int = 1, + ensembling_kwargs: Optional[Dict[str, Any]] = None, + latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "np", + output_uncertainty: bool = False, + output_latent: bool = False, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline. + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), + `List[torch.Tensor]`: An input image or images used as an input for the depth estimation task. For + arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible + by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or + three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the + same width and height. + num_inference_steps (`int`, *optional*, defaults to `None`): + Number of denoising diffusion steps during inference. The default value `None` results in automatic + selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 + for Marigold-LCM models. + ensemble_size (`int`, defaults to `1`): + Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for + faster inference. + processing_resolution (`int`, *optional*, defaults to `None`): + Effective processing resolution. When set to `0`, matches the larger input image dimension. This + produces crisper predictions, but may also lead to the overall loss of global context. The default + value `None` resolves to the optimal value from the model config. + match_input_resolution (`bool`, *optional*, defaults to `True`): + When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer + side of the output will equal to `processing_resolution`. + resample_method_input (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize input images to `processing_resolution`. The accepted values are: + `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + resample_method_output (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize output predictions to match the input resolution. The accepted values + are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + batch_size (`int`, *optional*, defaults to `1`): + Batch size; only matters when setting `ensemble_size` or passing a tensor of images. + ensembling_kwargs (`dict`, *optional*, defaults to `None`) + Extra dictionary with arguments for precise ensembling control. The following options are available: + - reduction (`str`, *optional*, defaults to `"median"`): Defines the ensembling function applied in + every pixel location, can be either `"median"` or `"mean"`. + - regularizer_strength (`float`, *optional*, defaults to `0.02`): Strength of the regularizer that + pulls the aligned predictions to the unit range from 0 to 1. + - max_iter (`int`, *optional*, defaults to `2`): Maximum number of the alignment solver steps. Refer to + `scipy.optimize.minimize` function, `options` argument. + - tol (`float`, *optional*, defaults to `1e-3`): Alignment solver tolerance. The solver stops when the + tolerance is reached. + - max_res (`int`, *optional*, defaults to `None`): Resolution at which the alignment is performed; + `None` matches the `processing_resolution`. + latents (`torch.Tensor`, or `List[torch.Tensor]`, *optional*, defaults to `None`): + Latent noise tensors to replace the random initialization. These can be taken from the previous + function call's output. + generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`): + Random number generator object to ensure reproducibility. + output_type (`str`, *optional*, defaults to `"np"`): + Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted + values are: `"np"` (numpy array) or `"pt"` (torch tensor). + output_uncertainty (`bool`, *optional*, defaults to `False`): + When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that + the `ensemble_size` argument is set to a value above 2. + output_latent (`bool`, *optional*, defaults to `False`): + When enabled, the output's `latent` field contains the latent codes corresponding to the predictions + within the ensemble. These codes can be saved, modified, and used for subsequent calls with the + `latents` argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.marigold.MarigoldDepthOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.marigold.MarigoldDepthOutput`] is returned, otherwise a + `tuple` is returned where the first element is the prediction, the second element is the uncertainty + (or `None`), and the third is the latent (or `None`). + """ + + # 0. Resolving variables. + device = self._execution_device + dtype = self.dtype + + # Model-specific optimal default values leading to fast and reasonable results. + if num_inference_steps is None: + num_inference_steps = self.default_denoising_steps + if processing_resolution is None: + processing_resolution = self.default_processing_resolution + + # 1. Check inputs. + num_images = self.check_inputs( + image, + num_inference_steps, + ensemble_size, + processing_resolution, + resample_method_input, + resample_method_output, + batch_size, + ensembling_kwargs, + latents, + generator, + output_type, + output_uncertainty, + ) + + # 2. Prepare empty text conditioning. + # Model invocation: self.tokenizer, self.text_encoder. + if self.empty_text_embedding is None: + prompt = "" + text_inputs = self.tokenizer( + prompt, + padding="do_not_pad", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] + + # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, + # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where + # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are + # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` + # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of + # operation and leads to the most reasonable results. Using the native image resolution or any other processing + # resolution can lead to loss of either fine details or global context in the output predictions. + image, padding, original_resolution = self.image_processor.preprocess( + image, processing_resolution, resample_method_input, device, dtype + ) # [N,3,PPH,PPW] + + # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` + # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. + # Latents of each such predictions across all input images and all ensemble members are represented in the + # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded + # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure + # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline + # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken + # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled + # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space + # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. + # Model invocation: self.vae.encoder. + image_latent, pred_latent = self.prepare_latents( + image, latents, generator, ensemble_size, batch_size + ) # [N*E,4,h,w], [N*E,4,h,w] + + del image + + batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat( + batch_size, 1, 1 + ) # [B,1024,2] + + # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`. + # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and + # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by + # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded + # model. + # Model invocation: self.unet. + pred_latents = [] + + for i in self.progress_bar( + range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..." + ): + batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w] + batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w] + effective_batch_size = batch_image_latent.shape[0] + text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024] + + self.scheduler.set_timesteps(num_inference_steps, device=device) + for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."): + batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,8,h,w] + noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w] + batch_pred_latent = self.scheduler.step( + noise, t, batch_pred_latent, generator=generator + ).prev_sample # [B,4,h,w] + + pred_latents.append(batch_pred_latent) + + pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w] + + del ( + pred_latents, + image_latent, + batch_empty_text_embedding, + batch_image_latent, + batch_pred_latent, + text, + batch_latent, + noise, + ) + + # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`, + # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`. + # Model invocation: self.vae.decoder. + prediction = torch.cat( + [ + self.decode_prediction(pred_latent[i : i + batch_size]) + for i in range(0, pred_latent.shape[0], batch_size) + ], + dim=0, + ) # [N*E,1,PPH,PPW] + + if not output_latent: + pred_latent = None + + # 7. Remove padding. The output shape is (PH, PW). + prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,1,PH,PW] + + # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N` + # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape + # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for + # each group independently, it stacks them respectively into batches of `N` almost final predictions and + # uncertainty maps. + uncertainty = None + if ensemble_size > 1: + prediction = prediction.reshape(num_images, ensemble_size, *prediction.shape[1:]) # [N,E,1,PH,PW] + prediction = [ + self.ensemble_depth( + prediction[i], + self.scale_invariant, + self.shift_invariant, + output_uncertainty, + **(ensembling_kwargs or {}), + ) + for i in range(num_images) + ] # [ [[1,1,PH,PW], [1,1,PH,PW]], ... ] + prediction, uncertainty = zip(*prediction) # [[1,1,PH,PW], ... ], [[1,1,PH,PW], ... ] + prediction = torch.cat(prediction, dim=0) # [N,1,PH,PW] + if output_uncertainty: + uncertainty = torch.cat(uncertainty, dim=0) # [N,1,PH,PW] + else: + uncertainty = None + + # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the + # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled. + # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by + # setting the `resample_method_output` parameter (e.g., to `"nearest"`). + if match_input_resolution: + prediction = self.image_processor.resize_antialias( + prediction, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.resize_antialias( + uncertainty, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + + # 10. Prepare the final outputs. + if output_type == "np": + prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,1] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.pt_to_numpy(uncertainty) # [N,H,W,1] + + # 11. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (prediction, uncertainty, pred_latent) + + return MarigoldDepthOutput( + prediction=prediction, + uncertainty=uncertainty, + latent=pred_latent, + ) + + def prepare_latents( + self, + image: torch.Tensor, + latents: Optional[torch.Tensor], + generator: Optional[torch.Generator], + ensemble_size: int, + batch_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + def retrieve_latents(encoder_output): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + image_latent = torch.cat( + [ + retrieve_latents(self.vae.encode(image[i : i + batch_size])) + for i in range(0, image.shape[0], batch_size) + ], + dim=0, + ) # [N,4,h,w] + image_latent = image_latent * self.vae.config.scaling_factor + image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] + + pred_latent = latents + if pred_latent is None: + pred_latent = randn_tensor( + image_latent.shape, + generator=generator, + device=image_latent.device, + dtype=image_latent.dtype, + ) # [N*E,4,h,w] + + return image_latent, pred_latent + + def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: + if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: + raise ValueError( + f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." + ) + + prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] + + prediction = prediction.mean(dim=1, keepdim=True) # [B,1,H,W] + prediction = torch.clip(prediction, -1.0, 1.0) # [B,1,H,W] + prediction = (prediction + 1.0) / 2.0 + + return prediction # [B,1,H,W] + + @staticmethod + def ensemble_depth( + depth: torch.Tensor, + scale_invariant: bool = True, + shift_invariant: bool = True, + output_uncertainty: bool = False, + reduction: str = "median", + regularizer_strength: float = 0.02, + max_iter: int = 2, + tol: float = 1e-3, + max_res: int = 1024, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Ensembles the depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the + number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for + depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The + alignment happens when the predictions have one or more degrees of freedom, that is when they are either + affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only + `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`) + alignment is skipped and only ensembling is performed. + + Args: + depth (`torch.Tensor`): + Input ensemble depth maps. + scale_invariant (`bool`, *optional*, defaults to `True`): + Whether to treat predictions as scale-invariant. + shift_invariant (`bool`, *optional*, defaults to `True`): + Whether to treat predictions as shift-invariant. + output_uncertainty (`bool`, *optional*, defaults to `False`): + Whether to output uncertainty map. + reduction (`str`, *optional*, defaults to `"median"`): + Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and + `"median"`. + regularizer_strength (`float`, *optional*, defaults to `0.02`): + Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1. + max_iter (`int`, *optional*, defaults to `2`): + Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options` + argument. + tol (`float`, *optional*, defaults to `1e-3`): + Alignment solver tolerance. The solver stops when the tolerance is reached. + max_res (`int`, *optional*, defaults to `1024`): + Resolution at which the alignment is performed; `None` matches the `processing_resolution`. + Returns: + A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape: + `(1, 1, H, W)`. + """ + if depth.dim() != 4 or depth.shape[1] != 1: + raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.") + if reduction not in ("mean", "median"): + raise ValueError(f"Unrecognized reduction method: {reduction}.") + if not scale_invariant and shift_invariant: + raise ValueError("Pure shift-invariant ensembling is not supported.") + + def init_param(depth: torch.Tensor): + init_min = depth.reshape(ensemble_size, -1).min(dim=1).values + init_max = depth.reshape(ensemble_size, -1).max(dim=1).values + + if scale_invariant and shift_invariant: + init_s = 1.0 / (init_max - init_min).clamp(min=1e-6) + init_t = -init_s * init_min + param = torch.cat((init_s, init_t)).cpu().numpy() + elif scale_invariant: + init_s = 1.0 / init_max.clamp(min=1e-6) + param = init_s.cpu().numpy() + else: + raise ValueError("Unrecognized alignment.") + + return param + + def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor: + if scale_invariant and shift_invariant: + s, t = np.split(param, 2) + s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1) + t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1) + out = depth * s + t + elif scale_invariant: + s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1) + out = depth * s + else: + raise ValueError("Unrecognized alignment.") + return out + + def ensemble( + depth_aligned: torch.Tensor, return_uncertainty: bool = False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + uncertainty = None + if reduction == "mean": + prediction = torch.mean(depth_aligned, dim=0, keepdim=True) + if return_uncertainty: + uncertainty = torch.std(depth_aligned, dim=0, keepdim=True) + elif reduction == "median": + prediction = torch.median(depth_aligned, dim=0, keepdim=True).values + if return_uncertainty: + uncertainty = torch.median(torch.abs(depth_aligned - prediction), dim=0, keepdim=True).values + else: + raise ValueError(f"Unrecognized reduction method: {reduction}.") + return prediction, uncertainty + + def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float: + cost = 0.0 + depth_aligned = align(depth, param) + + for i, j in torch.combinations(torch.arange(ensemble_size)): + diff = depth_aligned[i] - depth_aligned[j] + cost += (diff**2).mean().sqrt().item() + + if regularizer_strength > 0: + prediction, _ = ensemble(depth_aligned, return_uncertainty=False) + err_near = (0.0 - prediction.min()).abs().item() + err_far = (1.0 - prediction.max()).abs().item() + cost += (err_near + err_far) * regularizer_strength + + return cost + + def compute_param(depth: torch.Tensor): + import scipy + + depth_to_align = depth.to(torch.float32) + if max_res is not None and max(depth_to_align.shape[2:]) > max_res: + depth_to_align = MarigoldImageProcessor.resize_to_max_edge(depth_to_align, max_res, "nearest-exact") + + param = init_param(depth_to_align) + + res = scipy.optimize.minimize( + partial(cost_fn, depth=depth_to_align), + param, + method="BFGS", + tol=tol, + options={"maxiter": max_iter, "disp": False}, + ) + + return res.x + + requires_aligning = scale_invariant or shift_invariant + ensemble_size = depth.shape[0] + + if requires_aligning: + param = compute_param(depth) + depth = align(depth, param) + + depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty) + + depth_max = depth.max() + if scale_invariant and shift_invariant: + depth_min = depth.min() + elif scale_invariant: + depth_min = 0 + else: + raise ValueError("Unrecognized alignment.") + depth_range = (depth_max - depth_min).clamp(min=1e-6) + depth = (depth - depth_min) / depth_range + if output_uncertainty: + uncertainty /= depth_range + + return depth, uncertainty # [1,1,H,W], [1,1,H,W] diff --git a/diffusers/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py b/diffusers/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py new file mode 100755 index 0000000..aa9ad36 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py @@ -0,0 +1,690 @@ +# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput +from ...models import ( + AutoencoderKL, + UNet2DConditionModel, +) +from ...schedulers import ( + DDIMScheduler, + LCMScheduler, +) +from ...utils import ( + BaseOutput, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .marigold_image_processing import MarigoldImageProcessor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ +Examples: +```py +>>> import diffusers +>>> import torch + +>>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( +... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16 +... ).to("cuda") + +>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +>>> normals = pipe(image) + +>>> vis = pipe.image_processor.visualize_normals(normals.prediction) +>>> vis[0].save("einstein_normals.png") +``` +""" + + +@dataclass +class MarigoldNormalsOutput(BaseOutput): + """ + Output class for Marigold monocular normals prediction pipeline. + + Args: + prediction (`np.ndarray`, `torch.Tensor`): + Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height + \times width$, regardless of whether the images were passed as a 4D array or a list. + uncertainty (`None`, `np.ndarray`, `torch.Tensor`): + Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages + \times 1 \times height \times width$. + latent (`None`, `torch.Tensor`): + Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. + The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. + """ + + prediction: Union[np.ndarray, torch.Tensor] + uncertainty: Union[None, np.ndarray, torch.Tensor] + latent: Union[None, torch.Tensor] + + +class MarigoldNormalsPipeline(DiffusionPipeline): + """ + Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the normals latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent + representations. + scheduler (`DDIMScheduler` or `LCMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + prediction_type (`str`, *optional*): + Type of predictions made by the model. + use_full_z_range (`bool`, *optional*): + Whether the normals predicted by this model utilize the full range of the Z dimension, or only its positive + half. + default_denoising_steps (`int`, *optional*): + The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable + quality with the given model. This value must be set in the model config. When the pipeline is called + without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure + reasonable results with various model flavors compatible with the pipeline, such as those relying on very + short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). + default_processing_resolution (`int`, *optional*): + The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in + the model config. When the pipeline is called without explicitly setting `processing_resolution`, the + default value is used. This is required to ensure reasonable results with various model flavors trained + with varying optimal processing resolution values. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + supported_prediction_types = ("normals",) + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, LCMScheduler], + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + prediction_type: Optional[str] = None, + use_full_z_range: Optional[bool] = True, + default_denoising_steps: Optional[int] = None, + default_processing_resolution: Optional[int] = None, + ): + super().__init__() + + if prediction_type not in self.supported_prediction_types: + logger.warning( + f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: " + f"{self.supported_prediction_types}." + ) + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + self.register_to_config( + use_full_z_range=use_full_z_range, + default_denoising_steps=default_denoising_steps, + default_processing_resolution=default_processing_resolution, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.use_full_z_range = use_full_z_range + self.default_denoising_steps = default_denoising_steps + self.default_processing_resolution = default_processing_resolution + + self.empty_text_embedding = None + + self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + image: PipelineImageInput, + num_inference_steps: int, + ensemble_size: int, + processing_resolution: int, + resample_method_input: str, + resample_method_output: str, + batch_size: int, + ensembling_kwargs: Optional[Dict[str, Any]], + latents: Optional[torch.Tensor], + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + output_type: str, + output_uncertainty: bool, + ) -> int: + if num_inference_steps is None: + raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") + if num_inference_steps < 1: + raise ValueError("`num_inference_steps` must be positive.") + if ensemble_size < 1: + raise ValueError("`ensemble_size` must be positive.") + if ensemble_size == 2: + logger.warning( + "`ensemble_size` == 2 results are similar to no ensembling (1); " + "consider increasing the value to at least 3." + ) + if ensemble_size == 1 and output_uncertainty: + raise ValueError( + "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " + "greater than 1." + ) + if processing_resolution is None: + raise ValueError( + "`processing_resolution` is not specified and could not be resolved from the model config." + ) + if processing_resolution < 0: + raise ValueError( + "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " + "downsampled processing." + ) + if processing_resolution % self.vae_scale_factor != 0: + raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") + if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_input` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_output` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if batch_size < 1: + raise ValueError("`batch_size` must be positive.") + if output_type not in ["pt", "np"]: + raise ValueError("`output_type` must be one of `pt` or `np`.") + if latents is not None and generator is not None: + raise ValueError("`latents` and `generator` cannot be used together.") + if ensembling_kwargs is not None: + if not isinstance(ensembling_kwargs, dict): + raise ValueError("`ensembling_kwargs` must be a dictionary.") + if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): + raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") + + # image checks + num_images = 0 + W, H = None, None + if not isinstance(image, list): + image = [image] + for i, img in enumerate(image): + if isinstance(img, np.ndarray) or torch.is_tensor(img): + if img.ndim not in (2, 3, 4): + raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") + H_i, W_i = img.shape[-2:] + N_i = 1 + if img.ndim == 4: + N_i = img.shape[0] + elif isinstance(img, Image.Image): + W_i, H_i = img.size + N_i = 1 + else: + raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") + if W is None: + W, H = W_i, H_i + elif (W, H) != (W_i, H_i): + raise ValueError( + f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" + ) + num_images += N_i + + # latents checks + if latents is not None: + if not torch.is_tensor(latents): + raise ValueError("`latents` must be a torch.Tensor.") + if latents.dim() != 4: + raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") + + if processing_resolution > 0: + max_orig = max(H, W) + new_H = H * processing_resolution // max_orig + new_W = W * processing_resolution // max_orig + if new_H == 0 or new_W == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") + W, H = new_W, new_H + w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor + h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor + shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) + + if latents.shape != shape_expected: + raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") + + # generator checks + if generator is not None: + if isinstance(generator, list): + if len(generator) != num_images * ensemble_size: + raise ValueError( + "The number of generators must match the total number of ensemble members for all input images." + ) + if not all(g.device.type == generator[0].device.type for g in generator): + raise ValueError("`generator` device placement is not consistent in the list.") + elif not isinstance(generator, torch.Generator): + raise ValueError(f"Unsupported generator type: {type(generator)}.") + + return num_images + + def progress_bar(self, iterable=None, total=None, desc=None, leave=True): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + progress_bar_config = dict(**self._progress_bar_config) + progress_bar_config["desc"] = progress_bar_config.get("desc", desc) + progress_bar_config["leave"] = progress_bar_config.get("leave", leave) + if iterable is not None: + return tqdm(iterable, **progress_bar_config) + elif total is not None: + return tqdm(total=total, **progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + num_inference_steps: Optional[int] = None, + ensemble_size: int = 1, + processing_resolution: Optional[int] = None, + match_input_resolution: bool = True, + resample_method_input: str = "bilinear", + resample_method_output: str = "bilinear", + batch_size: int = 1, + ensembling_kwargs: Optional[Dict[str, Any]] = None, + latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "np", + output_uncertainty: bool = False, + output_latent: bool = False, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline. + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), + `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For + arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible + by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or + three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the + same width and height. + num_inference_steps (`int`, *optional*, defaults to `None`): + Number of denoising diffusion steps during inference. The default value `None` results in automatic + selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 + for Marigold-LCM models. + ensemble_size (`int`, defaults to `1`): + Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for + faster inference. + processing_resolution (`int`, *optional*, defaults to `None`): + Effective processing resolution. When set to `0`, matches the larger input image dimension. This + produces crisper predictions, but may also lead to the overall loss of global context. The default + value `None` resolves to the optimal value from the model config. + match_input_resolution (`bool`, *optional*, defaults to `True`): + When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer + side of the output will equal to `processing_resolution`. + resample_method_input (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize input images to `processing_resolution`. The accepted values are: + `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + resample_method_output (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize output predictions to match the input resolution. The accepted values + are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + batch_size (`int`, *optional*, defaults to `1`): + Batch size; only matters when setting `ensemble_size` or passing a tensor of images. + ensembling_kwargs (`dict`, *optional*, defaults to `None`) + Extra dictionary with arguments for precise ensembling control. The following options are available: + - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in + every pixel location, can be either `"closest"` or `"mean"`. + latents (`torch.Tensor`, *optional*, defaults to `None`): + Latent noise tensors to replace the random initialization. These can be taken from the previous + function call's output. + generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`): + Random number generator object to ensure reproducibility. + output_type (`str`, *optional*, defaults to `"np"`): + Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted + values are: `"np"` (numpy array) or `"pt"` (torch tensor). + output_uncertainty (`bool`, *optional*, defaults to `False`): + When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that + the `ensemble_size` argument is set to a value above 2. + output_latent (`bool`, *optional*, defaults to `False`): + When enabled, the output's `latent` field contains the latent codes corresponding to the predictions + within the ensemble. These codes can be saved, modified, and used for subsequent calls with the + `latents` argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a + `tuple` is returned where the first element is the prediction, the second element is the uncertainty + (or `None`), and the third is the latent (or `None`). + """ + + # 0. Resolving variables. + device = self._execution_device + dtype = self.dtype + + # Model-specific optimal default values leading to fast and reasonable results. + if num_inference_steps is None: + num_inference_steps = self.default_denoising_steps + if processing_resolution is None: + processing_resolution = self.default_processing_resolution + + # 1. Check inputs. + num_images = self.check_inputs( + image, + num_inference_steps, + ensemble_size, + processing_resolution, + resample_method_input, + resample_method_output, + batch_size, + ensembling_kwargs, + latents, + generator, + output_type, + output_uncertainty, + ) + + # 2. Prepare empty text conditioning. + # Model invocation: self.tokenizer, self.text_encoder. + if self.empty_text_embedding is None: + prompt = "" + text_inputs = self.tokenizer( + prompt, + padding="do_not_pad", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] + + # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, + # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where + # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are + # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` + # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of + # operation and leads to the most reasonable results. Using the native image resolution or any other processing + # resolution can lead to loss of either fine details or global context in the output predictions. + image, padding, original_resolution = self.image_processor.preprocess( + image, processing_resolution, resample_method_input, device, dtype + ) # [N,3,PPH,PPW] + + # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` + # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. + # Latents of each such predictions across all input images and all ensemble members are represented in the + # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded + # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure + # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline + # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken + # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled + # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space + # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. + # Model invocation: self.vae.encoder. + image_latent, pred_latent = self.prepare_latents( + image, latents, generator, ensemble_size, batch_size + ) # [N*E,4,h,w], [N*E,4,h,w] + + del image + + batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat( + batch_size, 1, 1 + ) # [B,1024,2] + + # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`. + # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and + # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by + # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded + # model. + # Model invocation: self.unet. + pred_latents = [] + + for i in self.progress_bar( + range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..." + ): + batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w] + batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w] + effective_batch_size = batch_image_latent.shape[0] + text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024] + + self.scheduler.set_timesteps(num_inference_steps, device=device) + for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."): + batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,8,h,w] + noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w] + batch_pred_latent = self.scheduler.step( + noise, t, batch_pred_latent, generator=generator + ).prev_sample # [B,4,h,w] + + pred_latents.append(batch_pred_latent) + + pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w] + + del ( + pred_latents, + image_latent, + batch_empty_text_embedding, + batch_image_latent, + batch_pred_latent, + text, + batch_latent, + noise, + ) + + # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`, + # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`. + # Model invocation: self.vae.decoder. + prediction = torch.cat( + [ + self.decode_prediction(pred_latent[i : i + batch_size]) + for i in range(0, pred_latent.shape[0], batch_size) + ], + dim=0, + ) # [N*E,3,PPH,PPW] + + if not output_latent: + pred_latent = None + + # 7. Remove padding. The output shape is (PH, PW). + prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW] + + # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N` + # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape + # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for + # each group independently, it stacks them respectively into batches of `N` almost final predictions and + # uncertainty maps. + uncertainty = None + if ensemble_size > 1: + prediction = prediction.reshape(num_images, ensemble_size, *prediction.shape[1:]) # [N,E,3,PH,PW] + prediction = [ + self.ensemble_normals(prediction[i], output_uncertainty, **(ensembling_kwargs or {})) + for i in range(num_images) + ] # [ [[1,3,PH,PW], [1,1,PH,PW]], ... ] + prediction, uncertainty = zip(*prediction) # [[1,3,PH,PW], ... ], [[1,1,PH,PW], ... ] + prediction = torch.cat(prediction, dim=0) # [N,3,PH,PW] + if output_uncertainty: + uncertainty = torch.cat(uncertainty, dim=0) # [N,1,PH,PW] + else: + uncertainty = None + + # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the + # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled. + # After upsampling, the native resolution normal maps are renormalized to unit length to reduce the artifacts. + # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by + # setting the `resample_method_output` parameter (e.g., to `"nearest"`). + if match_input_resolution: + prediction = self.image_processor.resize_antialias( + prediction, original_resolution, resample_method_output, is_aa=False + ) # [N,3,H,W] + prediction = self.normalize_normals(prediction) # [N,3,H,W] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.resize_antialias( + uncertainty, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + + # 10. Prepare the final outputs. + if output_type == "np": + prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.pt_to_numpy(uncertainty) # [N,H,W,1] + + # 11. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (prediction, uncertainty, pred_latent) + + return MarigoldNormalsOutput( + prediction=prediction, + uncertainty=uncertainty, + latent=pred_latent, + ) + + # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents + def prepare_latents( + self, + image: torch.Tensor, + latents: Optional[torch.Tensor], + generator: Optional[torch.Generator], + ensemble_size: int, + batch_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + def retrieve_latents(encoder_output): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + image_latent = torch.cat( + [ + retrieve_latents(self.vae.encode(image[i : i + batch_size])) + for i in range(0, image.shape[0], batch_size) + ], + dim=0, + ) # [N,4,h,w] + image_latent = image_latent * self.vae.config.scaling_factor + image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] + + pred_latent = latents + if pred_latent is None: + pred_latent = randn_tensor( + image_latent.shape, + generator=generator, + device=image_latent.device, + dtype=image_latent.dtype, + ) # [N*E,4,h,w] + + return image_latent, pred_latent + + def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: + if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: + raise ValueError( + f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." + ) + + prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] + + prediction = torch.clip(prediction, -1.0, 1.0) + + if not self.use_full_z_range: + prediction[:, 2, :, :] *= 0.5 + prediction[:, 2, :, :] += 0.5 + + prediction = self.normalize_normals(prediction) # [B,3,H,W] + + return prediction # [B,3,H,W] + + @staticmethod + def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + if normals.dim() != 4 or normals.shape[1] != 3: + raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") + + norm = torch.norm(normals, dim=1, keepdim=True) + normals /= norm.clamp(min=eps) + + return normals + + @staticmethod + def ensemble_normals( + normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest" + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is + the number of ensemble members for a given prediction of size `(H x W)`. + + Args: + normals (`torch.Tensor`): + Input ensemble normals maps. + output_uncertainty (`bool`, *optional*, defaults to `False`): + Whether to output uncertainty map. + reduction (`str`, *optional*, defaults to `"closest"`): + Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and + `"mean"`. + + Returns: + A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of + uncertainties of shape `(1, 1, H, W)`. + """ + if normals.dim() != 4 or normals.shape[1] != 3: + raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") + if reduction not in ("closest", "mean"): + raise ValueError(f"Unrecognized reduction method: {reduction}.") + + mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W] + mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W] + + sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W] + sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16 + + uncertainty = None + if output_uncertainty: + uncertainty = sim_cos.arccos() # [E,1,H,W] + uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W] + + if reduction == "mean": + return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W] + + closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W] + closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W] + closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W] + + return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W] diff --git a/diffusers/src/diffusers/pipelines/musicldm/__init__.py b/diffusers/src/diffusers/pipelines/musicldm/__init__.py new file mode 100755 index 0000000..ed71eeb --- /dev/null +++ b/diffusers/src/diffusers/pipelines/musicldm/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_musicldm"] = ["MusicLDMPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_musicldm import MusicLDMPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/diffusers/src/diffusers/pipelines/musicldm/pipeline_musicldm.py new file mode 100755 index 0000000..728635d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/musicldm/pipeline_musicldm.py @@ -0,0 +1,635 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import ( + ClapFeatureExtractor, + ClapModel, + ClapTextModelWithProjection, + RobertaTokenizer, + RobertaTokenizerFast, + SpeechT5HifiGan, +) + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_librosa_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin + + +if is_librosa_available(): + import librosa + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import MusicLDMPipeline + >>> import torch + >>> import scipy + + >>> repo_id = "ucsd-reach/musicldm" + >>> pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs" + >>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0] + + >>> # save the audio sample as a .wav file + >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio) + ``` +""" + + +class MusicLDMPipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for text-to-audio generation using MusicLDM. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.ClapModel`]): + Frozen text-audio embedding model (`ClapTextModel`), specifically the + [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. + tokenizer ([`PreTrainedTokenizer`]): + A [`~transformers.RobertaTokenizer`] to tokenize text. + feature_extractor ([`~transformers.ClapFeatureExtractor`]): + Feature extractor to compute mel-spectrograms from audio waveforms. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded audio latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + vocoder ([`~transformers.SpeechT5HifiGan`]): + Vocoder of class `SpeechT5HifiGan`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: Union[ClapTextModelWithProjection, ClapModel], + tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast], + feature_extractor: Optional[ClapFeatureExtractor], + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vocoder: SpeechT5HifiGan, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + feature_extractor=feature_extractor, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def _encode_prompt( + self, + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device (`torch.device`): + torch device + num_waveforms_per_prompt (`int`): + number of waveforms that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the audio generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLAP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder.get_text_features( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + ) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device) + + ( + bs_embed, + seq_len, + ) = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder.get_text_features( + uncond_input_ids, + attention_mask=attention_mask, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform + def mel_spectrogram_to_waveform(self, mel_spectrogram): + if mel_spectrogram.dim() == 4: + mel_spectrogram = mel_spectrogram.squeeze(1) + + waveform = self.vocoder(mel_spectrogram) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + waveform = waveform.cpu().float() + return waveform + + # Copied from diffusers.pipelines.audioldm2.pipeline_audioldm2.AudioLDM2Pipeline.score_waveforms + def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype): + if not is_librosa_available(): + logger.info( + "Automatic scoring of the generated audio waveforms against the input prompt text requires the " + "`librosa` package to resample the generated waveforms. Returning the audios in the order they were " + "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`." + ) + return audio + inputs = self.tokenizer(text, return_tensors="pt", padding=True) + resampled_audio = librosa.resample( + audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate + ) + inputs["input_features"] = self.feature_extractor( + list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate + ).input_features.type(dtype) + inputs = inputs.to(device) + + # compute the audio-text similarity score using the CLAP model + logits_per_text = self.text_encoder(**inputs).logits_per_text + # sort by the highest matching generations per prompt + indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt] + audio = torch.index_select(audio, 0, indices.reshape(-1).cpu()) + return audio + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.check_inputs + def check_inputs( + self, + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor + if audio_length_in_s < min_audio_length_in_s: + raise ValueError( + f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but " + f"is {audio_length_in_s}." + ) + + if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0: + raise ValueError( + f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the " + f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of " + f"{self.vae_scale_factor}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(self.vocoder.config.model_in_dim) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + model_sequence = [ + self.text_encoder.text_model, + self.text_encoder.text_projection, + self.unet, + self.vae, + self.vocoder, + self.text_encoder, + ] + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + audio_length_in_s: Optional[float] = None, + num_inference_steps: int = 200, + guidance_scale: float = 2.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_waveforms_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + output_type: Optional[str] = "np", + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + audio_length_in_s (`int`, *optional*, defaults to 10.24): + The length of the generated audio sample in seconds. + num_inference_steps (`int`, *optional*, defaults to 200): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 2.0): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, the text encoding + model is a joint text-audio model ([`~transformers.ClapModel`]), and the tokenizer is a + `[~transformers.ClapProcessor]`, then automatic scoring will be performed between the generated outputs + and the input text. This scoring ranks the generated waveforms based on their cosine similarity to text + input in the joint text-audio embedding space. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion + model (LDM) output. + + Examples: + + Returns: + [`~pipelines.AudioPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to spectrogram height + vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor + + height = int(audio_length_in_s / vocoder_upsample_factor) + + original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) + if height % self.vae_scale_factor != 0: + height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor + logger.info( + f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} " + f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the " + f"denoising process." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_latents, + height, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=None, + class_labels=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + self.maybe_free_model_hooks() + + # 8. Post-processing + if not output_type == "latent": + latents = 1 / self.vae.config.scaling_factor * latents + mel_spectrogram = self.vae.decode(latents).sample + else: + return AudioPipelineOutput(audios=latents) + + audio = self.mel_spectrogram_to_waveform(mel_spectrogram) + + audio = audio[:, :original_waveform_length] + + # 9. Automatic scoring + if num_waveforms_per_prompt > 1 and prompt is not None: + audio = self.score_waveforms( + text=prompt, + audio=audio, + num_waveforms_per_prompt=num_waveforms_per_prompt, + device=device, + dtype=prompt_embeds.dtype, + ) + + if output_type == "np": + audio = audio.numpy() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/diffusers/src/diffusers/pipelines/onnx_utils.py b/diffusers/src/diffusers/pipelines/onnx_utils.py new file mode 100755 index 0000000..11f2241 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/onnx_utils.py @@ -0,0 +1,215 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import shutil +from pathlib import Path +from typing import Optional, Union + +import numpy as np +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import validate_hf_hub_args + +from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging + + +if is_onnx_available(): + import onnxruntime as ort + + +logger = logging.get_logger(__name__) + +ORT_TO_NP_TYPE = { + "tensor(bool)": np.bool_, + "tensor(int8)": np.int8, + "tensor(uint8)": np.uint8, + "tensor(int16)": np.int16, + "tensor(uint16)": np.uint16, + "tensor(int32)": np.int32, + "tensor(uint32)": np.uint32, + "tensor(int64)": np.int64, + "tensor(uint64)": np.uint64, + "tensor(float16)": np.float16, + "tensor(float)": np.float32, + "tensor(double)": np.float64, +} + + +class OnnxRuntimeModel: + def __init__(self, model=None, **kwargs): + logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.") + self.model = model + self.model_save_dir = kwargs.get("model_save_dir", None) + self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME) + + def __call__(self, **kwargs): + inputs = {k: np.array(v) for k, v in kwargs.items()} + return self.model.run(None, inputs) + + @staticmethod + def load_model(path: Union[str, Path], provider=None, sess_options=None): + """ + Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider` + + Arguments: + path (`str` or `Path`): + Directory from which to load + provider(`str`, *optional*): + Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider` + """ + if provider is None: + logger.info("No onnxruntime provider specified, using CPUExecutionProvider") + provider = "CPUExecutionProvider" + + return ort.InferenceSession(path, providers=[provider], sess_options=sess_options) + + def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the + latest_model_name. + + Arguments: + save_directory (`str` or `Path`): + Directory where to save the model file. + file_name(`str`, *optional*): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the + model with a different name. + """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + + src_path = self.model_save_dir.joinpath(self.latest_model_name) + dst_path = Path(save_directory).joinpath(model_file_name) + try: + shutil.copyfile(src_path, dst_path) + except shutil.SameFileError: + pass + + # copy external weights (for models >2GB) + src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME) + if src_path.exists(): + dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME) + try: + shutil.copyfile(src_path, dst_path) + except shutil.SameFileError: + pass + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + **kwargs, + ): + """ + Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class + method.: + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + # saving model weights/files + self._save_pretrained(save_directory, **kwargs) + + @classmethod + @validate_hf_hub_args + def _from_pretrained( + cls, + model_id: Union[str, Path], + token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + provider: Optional[str] = None, + sess_options: Optional["ort.SessionOptions"] = None, + **kwargs, + ): + """ + Load a model from a directory or the HF Hub. + + Arguments: + model_id (`str` or `Path`): + Directory from which to load + token (`str` or `bool`): + Is needed to load models from a private or gated repository + revision (`str`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id + cache_dir (`Union[str, Path]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + file_name(`str`): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load + different model files from the same repository or directory. + provider(`str`): + The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`. + kwargs (`Dict`, *optional*): + kwargs will be passed to the model during initialization + """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + # load model from local directory + if os.path.isdir(model_id): + model = OnnxRuntimeModel.load_model( + Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options + ) + kwargs["model_save_dir"] = Path(model_id) + # load model from hub + else: + # download model + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_name, + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + ) + kwargs["model_save_dir"] = Path(model_cache_path).parent + kwargs["latest_model_name"] = Path(model_cache_path).name + model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options) + return cls(model=model, **kwargs) + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + model_id: Union[str, Path], + force_download: bool = True, + token: Optional[str] = None, + cache_dir: Optional[str] = None, + **model_kwargs, + ): + revision = None + if len(str(model_id).split("@")) == 2: + model_id, revision = model_id.split("@") + + return cls._from_pretrained( + model_id=model_id, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + token=token, + **model_kwargs, + ) diff --git a/diffusers/src/diffusers/pipelines/pag/__init__.py b/diffusers/src/diffusers/pipelines/pag/__init__.py new file mode 100755 index 0000000..d8842ce --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/__init__.py @@ -0,0 +1,69 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"] + _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"] + _import_structure["pipeline_pag_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetPAGImg2ImgPipeline"] + _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"] + _import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"] + _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"] + _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"] + _import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"] + _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"] + _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] + _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"] + _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline + from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline + from .pipeline_pag_controlnet_sd_xl_img2img import StableDiffusionXLControlNetPAGImg2ImgPipeline + from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline + from .pipeline_pag_kolors import KolorsPAGPipeline + from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline + from .pipeline_pag_sd import StableDiffusionPAGPipeline + from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline + from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline + from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline + from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline + from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/pag/pag_utils.py b/diffusers/src/diffusers/pipelines/pag/pag_utils.py new file mode 100755 index 0000000..728f730 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pag_utils.py @@ -0,0 +1,237 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Dict, List, Tuple, Union + +import torch +import torch.nn as nn + +from ...models.attention_processor import ( + Attention, + AttentionProcessor, + PAGCFGIdentitySelfAttnProcessor2_0, + PAGIdentitySelfAttnProcessor2_0, +) +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PAGMixin: + r"""Mixin class for [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377v1).""" + + def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): + r""" + Set the attention processor for the PAG layers. + """ + pag_attn_processors = self._pag_attn_processors + if pag_attn_processors is None: + raise ValueError( + "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters." + ) + + pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] + + if hasattr(self, "unet"): + model: nn.Module = self.unet + else: + model: nn.Module = self.transformer + + def is_self_attn(module: nn.Module) -> bool: + r""" + Check if the module is self-attention module based on its name. + """ + return isinstance(module, Attention) and not module.is_cross_attention + + def is_fake_integral_match(layer_id, name): + layer_id = layer_id.split(".")[-1] + name = name.split(".")[-1] + return layer_id.isnumeric() and name.isnumeric() and layer_id == name + + for layer_id in pag_applied_layers: + # for each PAG layer input, we find corresponding self-attention layers in the unet model + target_modules = [] + + for name, module in model.named_modules(): + # Identify the following simple cases: + # (1) Self Attention layer existing + # (2) Whether the module name matches pag layer id even partially + # (3) Make sure it's not a fake integral match if the layer_id ends with a number + # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" + if ( + is_self_attn(module) + and re.search(layer_id, name) is not None + and not is_fake_integral_match(layer_id, name) + ): + logger.debug(f"Applying PAG to layer: {name}") + target_modules.append(module) + + if len(target_modules) == 0: + raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") + + for module in target_modules: + module.processor = pag_attn_proc + + def _get_pag_scale(self, t): + r""" + Get the scale factor for the perturbed attention guidance at timestep `t`. + """ + + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t) + if signal_scale < 0: + signal_scale = 0 + return signal_scale + else: + return self.pag_scale + + def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): + r""" + Apply perturbed attention guidance to the noise prediction. + + Args: + noise_pred (torch.Tensor): The noise prediction tensor. + do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. + guidance_scale (float): The scale factor for the guidance term. + t (int): The current time step. + + Returns: + torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance. + """ + pag_scale = self._get_pag_scale(t) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + + pag_scale * (noise_pred_text - noise_pred_perturb) + ) + else: + noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) + noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) + return noise_pred + + def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): + """ + Prepares the perturbed attention guidance for the PAG model. + + Args: + cond (torch.Tensor): The conditional input tensor. + uncond (torch.Tensor): The unconditional input tensor. + do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance. + + Returns: + torch.Tensor: The prepared perturbed attention guidance tensor. + """ + + cond = torch.cat([cond] * 2, dim=0) + + if do_classifier_free_guidance: + cond = torch.cat([uncond, cond], dim=0) + return cond + + def set_pag_applied_layers( + self, + pag_applied_layers: Union[str, List[str]], + pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( + PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0(), + ), + ): + r""" + Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + + Args: + pag_applied_layers (`str` or `List[str]`): + One or more strings identifying the layer names, or a simple regex for matching multiple layers, where + PAG is to be applied. A few ways of expected usage are as follows: + - Single layers specified as - "blocks.{layer_index}" + - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] + - Multiple layers as a block name - "mid" + - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" + pag_attn_processors: + (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention + processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second + attention processor is for PAG with CFG disabled (unconditional only). + """ + + if not hasattr(self, "_pag_attn_processors"): + self._pag_attn_processors = None + + if not isinstance(pag_applied_layers, list): + pag_applied_layers = [pag_applied_layers] + if pag_attn_processors is not None: + if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: + raise ValueError("Expected a tuple of two attention processors") + + for i in range(len(pag_applied_layers)): + if not isinstance(pag_applied_layers[i], str): + raise ValueError( + f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" + ) + + self.pag_applied_layers = pag_applied_layers + self._pag_attn_processors = pag_attn_processors + + @property + def pag_scale(self) -> float: + r"""Get the scale factor for the perturbed attention guidance.""" + return self._pag_scale + + @property + def pag_adaptive_scale(self) -> float: + r"""Get the adaptive scale factor for the perturbed attention guidance.""" + return self._pag_adaptive_scale + + @property + def do_pag_adaptive_scaling(self) -> bool: + r"""Check if the adaptive scaling is enabled for the perturbed attention guidance.""" + return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 + + @property + def do_perturbed_attention_guidance(self) -> bool: + r"""Check if the perturbed attention guidance is enabled.""" + return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 + + @property + def pag_attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model + with the key as the name of the layer. + """ + + if self._pag_attn_processors is None: + return {} + + valid_attn_processors = {x.__class__ for x in self._pag_attn_processors} + + processors = {} + # We could have iterated through the self.components.items() and checked if a component is + # `ModelMixin` subclassed but that can include a VAE too. + if hasattr(self, "unet"): + denoiser_module = self.unet + elif hasattr(self, "transformer"): + denoiser_module = self.transformer + else: + raise ValueError("No denoiser module found.") + + for name, proc in denoiser_module.attn_processors.items(): + if proc.__class__ in valid_attn_processors: + processors[name] = proc + + return processors diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py new file mode 100755 index 0000000..9bac883 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py @@ -0,0 +1,1329 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..controlnet.multicontrolnet import MultiControlNetModel +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import AutoPipelineForText2Image, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, enable_pag=True + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting", + ... guidance_scale=7.5, + ... generator=generator, + ... image=canny_image, + ... pag_scale=10, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionControlNetPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: Union[str, List[str]] = "mid", + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + transposed_image = [list(t) for t in zip(*image)] + if len(transposed_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in transposed_image: + self.check_image(image_, prompt, prompt_embeds) + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + else: + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." + ) + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single + ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple + ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + # Nested lists as ControlNet condition + if isinstance(image[0], list): + # Transpose the nested image list + image = [list(t) for t in zip(*image)] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + controlnet_prompt_embeds = prompt_embeds + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + images = image if isinstance(image, list) else [image] + for i, single_image in enumerate(images): + if self.do_classifier_free_guidance: + single_image = single_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_image = self._prepare_perturbed_attention_guidance( + single_image, single_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_image = torch.cat([single_image] * 2) + single_image = single_image.to(device) + images[i] = single_image + + image = images if isinstance(image, list) else images[0] + + # 8. Denoising loop + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py new file mode 100755 index 0000000..247fc90 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -0,0 +1,1612 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +from ..controlnet.multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import AutoPipelineForText2Image, ControlNetModel, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... controlnet=controlnet, + ... vae=vae, + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image, pag_scale=0.3 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLControlNetPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + pag_applied_layers: Union[str, List[str]] = "mid", # ["down.block_2", "up.block_1.attentions_0"], "mid" + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + images = image if isinstance(image, list) else [image] + for i, single_image in enumerate(images): + if self.do_classifier_free_guidance: + single_image = single_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_image = self._prepare_perturbed_attention_guidance( + single_image, single_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_image = torch.cat([single_image] * 2) + single_image = single_image.to(device) + images[i] = single_image + + image = images if isinstance(image, list) else images[0] + + if ip_adapter_image_embeds is not None: + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=False, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py new file mode 100755 index 0000000..6639848 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -0,0 +1,1685 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +from ..controlnet.multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # pip install accelerate transformers safetensors diffusers + + >>> import torch + >>> import numpy as np + >>> from PIL import Image + + >>> from transformers import DPTFeatureExtractor, DPTForDepthEstimation + >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL + >>> from diffusers.utils import load_image + + + >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") + >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-depth-sdxl-1.0-small", + ... variant="fp16", + ... use_safetensors="True", + ... torch_dtype=torch.float16, + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetPAGImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... controlnet=controlnet, + ... vae=vae, + ... variant="fp16", + ... use_safetensors=True, + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe.enable_model_cpu_offload() + + + >>> def get_depth_map(image): + ... image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") + ... with torch.no_grad(), torch.autocast("cuda"): + ... depth_map = depth_estimator(image).predicted_depth + + ... depth_map = torch.nn.fuctional.interpolate( + ... depth_map.unsqueeze(1), + ... size=(1024, 1024), + ... mode="bicubic", + ... align_corners=False, + ... ) + ... depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + ... depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + ... depth_map = (depth_map - depth_min) / (depth_max - depth_min) + ... image = torch.cat([depth_map] * 3, dim=1) + ... image = image.permute(0, 2, 3, 1).cpu().numpy()[0] + ... image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) + ... return image + + + >>> prompt = "A robot, 4k photo" + >>> image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ).resize((1024, 1024)) + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> depth_image = get_depth_map(image) + + >>> images = pipe( + ... prompt, + ... image=image, + ... control_image=depth_image, + ... strength=0.99, + ... num_inference_steps=50, + ... controlnet_conditioning_scale=controlnet_conditioning_scale, + ... ).images + >>> images[0].save(f"robot_cat.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLControlNetPAGImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img.StableDiffusionXLControlNetImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + image, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.8, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The initial image will be used as the starting point for the image generation process. Can also accept + image latents as `image`, if passing latents directly, it will not be encoded again. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also + be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in + init, images must be passed as a list such that each element of the list can be correctly batched for + input to a single controlnet. + height (`int`, *optional*, defaults to the size of control_image): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to the size of control_image): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple` containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + control_image, + strength, + num_inference_steps, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image and controlnet_conditioning_image + image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = control_image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + control_images.append(control_image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + True, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(control_image, list): + original_size = original_size or control_image[0].shape[-2:] + else: + original_size = original_size or control_image.shape[-2:] + target_size = target_size or (height, width) + + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + add_text_embeds = pooled_prompt_embeds + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + control_images = control_image if isinstance(control_image, list) else [control_image] + for i, single_image in enumerate(control_images): + if self.do_classifier_free_guidance: + single_image = single_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_image = self._prepare_perturbed_attention_guidance( + single_image, single_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_image = torch.cat([single_image] * 2) + single_image = single_image.to(device) + control_images[i] = single_image + + control_image = control_images if isinstance(control_image, list) else control_images[0] + + if ip_adapter_image_embeds is not None: + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=False, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py new file mode 100755 index 0000000..63126cc --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -0,0 +1,953 @@ +# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.attention_processor import PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0 +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import ( + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=[14], + ... ).to("cuda") + + >>> # prompt = "an astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt, guidance_scale=4, pag_scale=3).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT and [Perturbed Attention + Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`MT5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: Optional[StableDiffusionSafetyChecker] = None, + feature_extractor: Optional[CLIPImageProcessor] = None, + requires_safety_checker: bool = True, + text_encoder_2: Optional[T5EncoderModel] = None, + tokenizer_2: Optional[MT5Tokenizer] = None, + pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16 + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + self.set_pag_applied_layers( + pag_applied_layers, pag_attn_processors=(PAGCFGHunyuanAttnProcessor2_0(), PAGHunyuanAttnProcessor2_0()) + ) + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + device: torch.device = None, + dtype: torch.dtype = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if device is None: + device = self._execution_device + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`Tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + prompt_attention_mask = self._prepare_perturbed_attention_guidance( + prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance + ) + prompt_embeds_2 = self._prepare_perturbed_attention_guidance( + prompt_embeds_2, negative_prompt_embeds_2, self.do_classifier_free_guidance + ) + prompt_attention_mask_2 = self._prepare_perturbed_attention_guidance( + prompt_attention_mask_2, negative_prompt_attention_mask_2, self.do_classifier_free_guidance + ) + add_time_ids = torch.cat([add_time_ids] * 3, dim=0) + style = torch.cat([style] * 3, dim=0) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # 9. Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_kolors.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_kolors.py new file mode 100755 index 0000000..3255bfd --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_kolors.py @@ -0,0 +1,1136 @@ +# Copyright 2024 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..kolors.pipeline_output import KolorsPipelineOutput +from ..kolors.text_encoder import ChatGLMModel +from ..kolors.tokenizer import ChatGLMTokenizer +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "Kwai-Kolors/Kolors-diffusers", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=["down.block_2.attentions_1", "up.block_0.attentions_1"], + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = ( + ... "A photo of a ladybug, macro, zoom, high quality, film, holding a wooden sign with the text 'KOLORS'" + ... ) + >>> image = pipe(prompt, guidance_scale=5.5, pag_scale=1.5).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class KolorsPAGPipeline( + DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin, PAGMixin +): + r""" + Pipeline for text-to-image generation using Kolors. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`ChatGLMModel`]): + Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b). + tokenizer (`ChatGLMTokenizer`): + Tokenizer of class + [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"False"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `Kwai-Kolors/Kolors-diffusers`. + pag_applied_layers (`str` or `List[str]``, *optional*, defaults to `"mid"`): + Set the transformer attention layers where to apply the perturbed attention guidance. Can be a string or a + list of strings with "down", "mid", "up", a whole transformer block or specific transformer block attention + layers, e.g.: + ["mid"] ["down", "mid"] ["down", "mid", "up.block_1"] ["down", "mid", "up.block_1.attentions_0", + "up.block_1.attentions_1"] + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = [ + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ChatGLMModel, + tokenizer: ChatGLMTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = False, + pag_applied_layers: Union[str, List[str]] = "mid", + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + """ + # from IPython import embed; embed(); exit() + device = device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer] + text_encoders = [self.text_encoder] + + if prompt_embeds is None: + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=text_inputs["input_ids"], + attention_mask=text_inputs["attention_mask"], + position_ids=text_inputs["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = prompt_embeds_list[0] + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=uncond_input["input_ids"], + attention_mask=uncond_input["attention_mask"], + position_ids=uncond_input["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = negative_prompt_embeds_list[0] + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.check_inputs + def check_inputs( + self, + prompt, + num_inference_steps, + height, + width, + negative_prompt=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + negative_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + num_inference_steps, + height, + width, + negative_prompt, + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return KolorsPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py new file mode 100755 index 0000000..8e5e6cb --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -0,0 +1,872 @@ +# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderKL, PixArtTransformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from .pag_utils import PAGMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", + ... torch_dtype=torch.float16, + ... pag_applied_layers=["blocks.14"], + ... enable_pag=True, + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A small cactus with a happy face in the Sahara desert" + >>> image = pipe(prompt, pag_scale=4.0, guidance_scale=1.0).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin): + r""" + [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation + using PixArt-Sigma. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + scheduler: KarrasDiffusionSchedulers, + pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300 + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 300, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 256: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, do_classifier_free_guidance + ) + prompt_attention_mask = self._prepare_perturbed_attention_guidance( + prompt_attention_mask, negative_prompt_attention_mask, do_classifier_free_guidance + ) + elif do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, do_classifier_free_guidance, guidance_scale, current_timestep + ) + elif do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd.py new file mode 100755 index 0000000..c6a4f7f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd.py @@ -0,0 +1,1050 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: Union[str, List[str]] = "mid", + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + None, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) + else None + ) + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py new file mode 100755 index 0000000..3035509 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py @@ -0,0 +1,985 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...models.attention_processor import PAGCFGJointAttnProcessor2_0, PAGJointAttnProcessor2_0 +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=["blocks.13"], + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt, guidance_scale=5.0, pag_scale=0.7).images[0] + >>> image.save("sd3_pag.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, PAGMixin): + r""" + [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation + using Stable Diffusion 3. + + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + self.set_pag_applied_layers( + pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0()) + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + self.tokenizer_max_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + max_sequence_length: int = 256, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale # + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + pooled_prompt_embeds = self._prepare_perturbed_attention_guidance( + pooled_prompt_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py new file mode 100755 index 0000000..1e81fa3 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -0,0 +1,866 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..animatediff.pipeline_output import AnimateDiffPipelineOutput +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pag_utils import PAGMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AnimateDiffPAGPipeline, MotionAdapter, DDIMScheduler + >>> from diffusers.utils import export_to_gif + + >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" + >>> motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-2" + >>> motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id) + >>> scheduler = DDIMScheduler.from_pretrained( + ... model_id, subfolder="scheduler", beta_schedule="linear", steps_offset=1, clip_sample=False + ... ) + >>> pipe = AnimateDiffPAGPipeline.from_pretrained( + ... model_id, + ... motion_adapter=motion_adapter, + ... scheduler=scheduler, + ... pag_applied_layers=["mid"], + ... torch_dtype=torch.float16, + ... ).to("cuda") + + >>> video = pipe( + ... prompt="car, futuristic cityscape with neon lights, street, no human", + ... negative_prompt="low quality, bad quality", + ... num_inference_steps=25, + ... guidance_scale=6.0, + ... pag_scale=3.0, + ... generator=torch.Generator().manual_seed(42), + ... ).frames[0] + + >>> export_to_gif(video, "animatediff_pag.gif") + ``` +""" + + +class AnimateDiffPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, + PAGMixin, +): + r""" + Pipeline for text-to-video generation using + [AnimateDiff](https://huggingface.co/docs/diffusers/en/api/pipelines/animatediff) and [Perturbed Attention + Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetMotionModel], + motion_adapter: MotionAdapter, + scheduler: KarrasDiffusionSchedulers, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + pag_applied_layers: Union[str, List[str]] = "mid_block.*attn1", # ["mid"], ["down_blocks.1"] + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.pia.pipeline_pia.PIAPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + num_frames: Optional[int] = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_chunk_size: int = 16, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat( + [latents] * (prompt_embeds.shape[0] // num_frames // latents.shape[0]) + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py new file mode 100755 index 0000000..18fc06c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -0,0 +1,1333 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + FusedAttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"],["down.block_1"],["up.block_0.attentions_0"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py new file mode 100755 index 0000000..dc85aaa --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -0,0 +1,1532 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForImage2Image + >>> from diffusers.utils import load_image + + >>> pipe = AutoPipelineForImage2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-refiner-1.0", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + strength: float = 0.3, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): + The image(s) to modify with the pipeline. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of + `denoising_start` being declared as an integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + num_inference_steps, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True if self.denoising_start is None else False + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + add_noise, + ) + # 7. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 9.1 Apply denoising_end + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py new file mode 100755 index 0000000..f5ebf43 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -0,0 +1,1764 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForInpainting + >>> from diffusers.utils import load_image + + >>> pipe = AutoPipelineForInpainting.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... torch_dtype=torch.float16, + ... variant="fp16", + ... enable_pag=True, + ... ) + >>> pipe.to("cuda") + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = load_image(img_url).convert("RGB") + >>> mask_image = load_image(mask_url).convert("RGB") + + >>> prompt = "A majestic tiger sitting on a bench" + >>> image = pipe( + ... prompt=prompt, + ... image=init_image, + ... mask_image=mask_image, + ... num_inference_steps=50, + ... strength=0.80, + ... pag_scale=0.3, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config + of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + "mask", + "masked_image_latents", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: torch.Tensor = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 0.9999): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. Note that in the case of `denoising_start` being declared as an + integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + None, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. set timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is not None: + masked_image = masked_image_latents + elif init_image.shape[1] == 4: + # if images are in latent space, we can't mask it + masked_image = None + else: + masked_image = init_image * (mask < 0.5) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = True if self.denoising_start is None else False + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + if self.do_perturbed_attention_guidance: + if self.do_classifier_free_guidance: + mask, _ = mask.chunk(2) + masked_image_latents, _ = masked_image_latents.chunk(2) + mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance) + masked_image_latents = self._prepare_perturbed_attention_guidance( + masked_image_latents, masked_image_latents, self.do_classifier_free_guidance + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 8.1 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 11.1 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_perturbed_attention_guidance: + init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2) + else: + init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + return StableDiffusionXLPipelineOutput(images=latents) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/paint_by_example/__init__.py b/diffusers/src/diffusers/pipelines/paint_by_example/__init__.py new file mode 100755 index 0000000..aaa775f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/paint_by_example/__init__.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Union + +import numpy as np +import PIL +from PIL import Image + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["image_encoder"] = ["PaintByExampleImageEncoder"] + _import_structure["pipeline_paint_by_example"] = ["PaintByExamplePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .image_encoder import PaintByExampleImageEncoder + from .pipeline_paint_by_example import PaintByExamplePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/paint_by_example/image_encoder.py b/diffusers/src/diffusers/pipelines/paint_by_example/image_encoder.py new file mode 100755 index 0000000..2fd0338 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/paint_by_example/image_encoder.py @@ -0,0 +1,67 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn +from transformers import CLIPPreTrainedModel, CLIPVisionModel + +from ...models.attention import BasicTransformerBlock +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PaintByExampleImageEncoder(CLIPPreTrainedModel): + def __init__(self, config, proj_size=None): + super().__init__(config) + self.proj_size = proj_size or getattr(config, "projection_dim", 768) + + self.model = CLIPVisionModel(config) + self.mapper = PaintByExampleMapper(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + self.proj_out = nn.Linear(config.hidden_size, self.proj_size) + + # uncondition for scaling + self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size))) + + def forward(self, pixel_values, return_uncond_vector=False): + clip_output = self.model(pixel_values=pixel_values) + latent_states = clip_output.pooler_output + latent_states = self.mapper(latent_states[:, None]) + latent_states = self.final_layer_norm(latent_states) + latent_states = self.proj_out(latent_states) + if return_uncond_vector: + return latent_states, self.uncond_vector + + return latent_states + + +class PaintByExampleMapper(nn.Module): + def __init__(self, config): + super().__init__() + num_layers = (config.num_hidden_layers + 1) // 5 + hid_size = config.hidden_size + num_heads = 1 + self.blocks = nn.ModuleList( + [ + BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states): + for block in self.blocks: + hidden_states = block(hidden_states) + + return hidden_states diff --git a/diffusers/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/diffusers/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py new file mode 100755 index 0000000..b225fd7 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -0,0 +1,626 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .image_encoder import PaintByExampleImageEncoder + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def prepare_mask_and_masked_image(image, mask): + """ + Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Batched mask + if mask.shape[0] == image.shape[0]: + mask = mask.unsqueeze(1) + else: + mask = mask.unsqueeze(0) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + assert mask.shape[1] == 1, "Mask image must have a single channel" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # paint-by-example inverses the mask + mask = 1 - mask + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + if isinstance(image, PIL.Image.Image): + image = [image] + + image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0) + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, PIL.Image.Image): + mask = [mask] + + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + + # paint-by-example inverses the mask + mask = 1 - mask + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * mask + + return mask, masked_image + + +class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + + + 🧪 This is an experimental feature! + + + + Pipeline for image-guided image inpainting using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`PaintByExampleImageEncoder`]): + Encodes the example input image. The `unet` is conditioned on the example image instead of a text prompt. + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + + """ + + # TODO: feature_extractor is required to encode initial images (if they are in PIL format), + # we should give a descriptive message if the pipeline doesn't have one. + + model_cpu_offload_seq = "unet->vae" + _exclude_from_cpu_offload = ["image_encoder"] + _optional_components = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: PaintByExampleImageEncoder, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + @torch.no_grad() + def __call__( + self, + example_image: Union[torch.Tensor, PIL.Image.Image], + image: Union[torch.Tensor, PIL.Image.Image], + mask_image: Union[torch.Tensor, PIL.Image.Image], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + example_image (`torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): + An example image to guide image generation. + image (`torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): + `Image` or tensor representing an image batch to be inpainted (parts of the image are masked out with + `mask_image` and repainted according to `prompt`). + mask_image (`torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): + `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted, + while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel + (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the + expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Example: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + >>> from diffusers import PaintByExamplePipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = ( + ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png" + ... ) + >>> mask_url = ( + ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png" + ... ) + >>> example_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + >>> example_image = download_image(example_url).resize((512, 512)) + + >>> pipe = PaintByExamplePipeline.from_pretrained( + ... "Fantasy-Studio/Paint-by-Example", + ... torch_dtype=torch.float16, + ... ) + >>> pipe = pipe.to("cuda") + + >>> image = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0] + >>> image + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 1. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 2. Preprocess mask and image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + height, width = masked_image.shape[-2:] + + # 3. Check inputs + self.check_inputs(example_image, height, width, callback_steps) + + # 4. Encode input image + image_embeddings = self._encode_image( + example_image, device, num_images_per_prompt, do_classifier_free_guidance + ) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + image_embeddings.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, masked_image_latents, mask], dim=1) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + self.maybe_free_model_hooks() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/pia/__init__.py b/diffusers/src/diffusers/pipelines/pia/__init__.py new file mode 100755 index 0000000..16e8004 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pia/__init__.py @@ -0,0 +1,46 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pia"] = ["PIAPipeline", "PIAPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .pipeline_pia import PIAPipeline, PIAPipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/pia/pipeline_pia.py b/diffusers/src/diffusers/pipelines/pia/pipeline_pia.py new file mode 100755 index 0000000..b7dfcd3 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pia/pipeline_pia.py @@ -0,0 +1,944 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import EulerDiscreteScheduler, MotionAdapter, PIAPipeline + >>> from diffusers.utils import export_to_gif, load_image + + >>> adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter") + >>> pipe = PIAPipeline.from_pretrained( + ... "SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16 + ... ) + + >>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + >>> image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true" + ... ) + >>> image = image.resize((512, 512)) + >>> prompt = "cat in a hat" + >>> negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality, deformed, distorted" + >>> generator = torch.Generator("cpu").manual_seed(0) + >>> output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator) + >>> frames = output.frames[0] + >>> export_to_gif(frames, "pia-animation.gif") + ``` +""" + +RANGE_LIST = [ + [1.0, 0.9, 0.85, 0.85, 0.85, 0.8], # 0 Small Motion + [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], # Moderate Motion + [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5], # Large Motion + [1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.85, 0.85, 0.9, 1.0], # Loop + [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8, 0.8, 1.0], # Loop + [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, 1.0], # Loop + [0.5, 0.4, 0.4, 0.4, 0.35, 0.3], # Style Transfer Candidate Small Motion + [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # Style Transfer Moderate Motion + [0.5, 0.2], # Style Transfer Large Motion +] + + +def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int): + assert num_frames > 0, "video_length should be greater than 0" + + assert num_frames > cond_frame, "video_length should be greater than cond_frame" + + range_list = RANGE_LIST + + assert motion_scale < len(range_list), f"motion_scale type{motion_scale} not implemented" + + coef = range_list[motion_scale] + coef = coef + ([coef[-1]] * (num_frames - len(coef))) + + order = [abs(i - cond_frame) for i in range(num_frames)] + coef = [coef[order[i]] for i in range(num_frames)] + + return coef + + +@dataclass +class PIAPipelineOutput(BaseOutput): + r""" + Output class for PIAPipeline. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`, NumPy array of + shape `(batch_size, num_frames, channels, height, width, Torch tensor of shape `(batch_size, num_frames, + channels, height, width)`. + """ + + frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]] + + +class PIAPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, + FreeInitMixin, +): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetMotionModel], + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + motion_adapter: Optional[MotionAdapter] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_masked_condition( + self, + image, + batch_size, + num_channels_latents, + num_frames, + height, + width, + dtype, + device, + generator, + motion_scale=0, + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + _, _, _, scaled_height, scaled_width = shape + + image = self.video_processor.preprocess(image) + image = image.to(device, dtype) + + if isinstance(generator, list): + image_latent = [ + self.vae.encode(image[k : k + 1]).latent_dist.sample(generator[k]) for k in range(batch_size) + ] + image_latent = torch.cat(image_latent, dim=0) + else: + image_latent = self.vae.encode(image).latent_dist.sample(generator) + + image_latent = image_latent.to(device=device, dtype=dtype) + image_latent = torch.nn.functional.interpolate(image_latent, size=[scaled_height, scaled_width]) + image_latent_padding = image_latent.clone() * self.vae.config.scaling_factor + + mask = torch.zeros((batch_size, 1, num_frames, scaled_height, scaled_width)).to(device=device, dtype=dtype) + mask_coef = prepare_mask_coef_by_statistics(num_frames, 0, motion_scale) + masked_image = torch.zeros(batch_size, 4, num_frames, scaled_height, scaled_width).to( + device=device, dtype=self.unet.dtype + ) + for f in range(num_frames): + mask[:, :, f, :, :] = mask_coef[f] + masked_image[:, :, f, :, :] = image_latent_padding.clone() + + mask = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask + masked_image = torch.cat([masked_image] * 2) if self.do_classifier_free_guidance else masked_image + + return mask, masked_image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Union[str, List[str]] = None, + strength: float = 1.0, + num_frames: Optional[int] = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + motion_scale: int = 0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to be used for video generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + motion_scale: (`int`, *optional*, defaults to 0): + Parameter that controls the amount and type of motion that is added to the image. Increasing the value + increases the amount of motion, while specific ranges of values control the type of motion that is + added. Must be between 0 and 8. Set between 0-2 to only increase the amount of motion. Set between 3-5 + to create looping motion. Set between 6-8 to perform motion with image style transfer. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + 4, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents=latents, + ) + mask, masked_image = self.prepare_masked_condition( + image, + batch_size * num_videos_per_prompt, + 4, + num_frames=num_frames, + height=height, + width=width, + dtype=self.unet.dtype, + device=device, + generator=generator, + motion_scale=motion_scale, + ) + if strength < 1.0: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + latents = self.scheduler.add_noise(masked_image[0], noise, latent_timestep) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 8. Denoising loop + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, mask, masked_image], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return PIAPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/pipeline_flax_utils.py b/diffusers/src/diffusers/pipelines/pipeline_flax_utils.py new file mode 100755 index 0000000..c4c2128 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pipeline_flax_utils.py @@ -0,0 +1,611 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +from typing import Any, Dict, List, Optional, Union + +import flax +import numpy as np +import PIL.Image +from flax.core.frozen_dict import FrozenDict +from huggingface_hub import create_repo, snapshot_download +from huggingface_hub.utils import validate_hf_hub_args +from PIL import Image +from tqdm.auto import tqdm + +from ..configuration_utils import ConfigMixin +from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin +from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin +from ..utils import ( + CONFIG_NAME, + BaseOutput, + PushToHubMixin, + http_user_agent, + is_transformers_available, + logging, +) + + +if is_transformers_available(): + from transformers import FlaxPreTrainedModel + +INDEX_FILE = "diffusion_flax_model.bin" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "FlaxModelMixin": ["save_pretrained", "from_pretrained"], + "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"], + "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + "ProcessorMixin": ["save_pretrained", "from_pretrained"], + "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +def import_flax_or_no_model(module, class_name): + try: + # 1. First make sure that if a Flax object is present, import this one + class_obj = getattr(module, "Flax" + class_name) + except AttributeError: + # 2. If this doesn't work, it's not a model and we don't append "Flax" + class_obj = getattr(module, class_name) + except AttributeError: + raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}") + + return class_obj + + +@flax.struct.dataclass +class FlaxImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin): + r""" + Base class for Flax-based pipelines. + + [`FlaxDiffusionPipeline`] stores all components (models, schedulers, and processors) for diffusion pipelines and + provides methods for loading, downloading and saving models. It also includes methods to: + + - enable/disable the progress bar for the denoising iteration + + Class attributes: + + - **config_name** ([`str`]) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + """ + + config_name = "model_index.json" + + def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + + for name, module in kwargs.items(): + if module is None: + register_dict = {name: (None, None)} + else: + # retrieve library + library = module.__module__.split(".")[0] + + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir + + # retrieve class_name + class_name = module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + params: Union[Dict, FrozenDict], + push_to_hub: bool = False, + **kwargs, + ): + # TODO: handle inference_state + """ + Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its + class implements both a save and loading method. The pipeline is easily reloaded using the + [`~FlaxDiffusionPipeline.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + self.save_config(save_directory) + + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name") + model_index_dict.pop("_diffusers_version") + model_index_dict.pop("_module", None) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + if sub_model is None: + # edge case for saving a pipeline with safety_checker=None + continue + + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + library = importlib.import_module(library_name) + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class, None) + if class_candidate is not None and issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + save_method = getattr(sub_model, save_method_name) + expects_params = "params" in set(inspect.signature(save_method).parameters.keys()) + + if expects_params: + save_method( + os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name] + ) + else: + save_method(os.path.join(save_directory, pipeline_component_name)) + + if push_to_hub: + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a Flax-based diffusion pipeline from pretrained pipeline weights. + + The pipeline is set in evaluation mode (`model.eval()) by default and dropout modules are deactivated. + + If you get the error message below, you need to finetune the weights for your downstream task: + + ``` + Some weights of FlaxUNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + ``` + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + using [`~FlaxDiffusionPipeline.save_pretrained`]. + dtype (`str` or `jnp.dtype`, *optional*): + Override the default `jnp.dtype` and load the model under this dtype. If `"auto"`, the dtype is + automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components) of the specific pipeline + class. The overwritten components are passed directly to the pipelines `__init__` method. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. + + + + Examples: + + ```py + >>> from diffusers import FlaxDiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> # Requires to be logged in to Hugging Face hub, + >>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", + ... variant="bf16", + ... dtype=jnp.bfloat16, + ... ) + + >>> # Download pipeline, but use a different scheduler + >>> from diffusers import FlaxDPMSolverMultistepScheduler + + >>> model_id = "runwayml/stable-diffusion-v1-5" + >>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained( + ... model_id, + ... subfolder="scheduler", + ... ) + + >>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained( + ... model_id, variant="bf16", dtype=jnp.bfloat16, scheduler=dpmpp + ... ) + >>> dpm_params["scheduler"] = dpmpp_state + ``` + """ + cache_dir = kwargs.pop("cache_dir", None) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + from_pt = kwargs.pop("from_pt", False) + use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False) + split_head_dim = kwargs.pop("split_head_dim", False) + dtype = kwargs.pop("dtype", None) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + config_dict = cls.load_config( + pretrained_model_name_or_path, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + ) + # make sure we only download sub-folders and `diffusers` filenames + folder_names = [k for k in config_dict.keys() if not k.startswith("_")] + allow_patterns = [os.path.join(k, "*") for k in folder_names] + allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] + + ignore_patterns = ["*.bin", "*.safetensors"] if not from_pt else [] + ignore_patterns += ["*.onnx", "*.onnx_data", "*.xml", "*.pb"] + + if cls != FlaxDiffusionPipeline: + requested_pipeline_class = cls.__name__ + else: + requested_pipeline_class = config_dict.get("_class_name", cls.__name__) + requested_pipeline_class = ( + requested_pipeline_class + if requested_pipeline_class.startswith("Flax") + else "Flax" + requested_pipeline_class + ) + + user_agent = {"pipeline_class": requested_pipeline_class} + user_agent = http_user_agent(user_agent) + + # download all allow_patterns + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.load_config(cached_folder) + + # 2. Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + if cls != FlaxDiffusionPipeline: + pipeline_class = cls + else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + class_name = ( + config_dict["_class_name"] + if config_dict["_class_name"].startswith("Flax") + else "Flax" + config_dict["_class_name"] + ) + pipeline_class = getattr(diffusers_module, class_name) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + + init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + # define init kwargs + init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} + init_kwargs = {**init_kwargs, **passed_pipe_kwargs} + + # remove `null` components + def load_module(name, value): + if value[0] is None: + return False + if name in passed_class_obj and passed_class_obj[name] is None: + return False + return True + + init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + + # Throw nice warnings / errors for fast accelerate loading + if len(unused_kwargs) > 0: + logger.warning( + f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." + ) + + # inference_params + params = {} + + # import it here to avoid circular import + from diffusers import pipelines + + # 3. Load each module in the pipeline + for name, (library_name, class_name) in init_dict.items(): + if class_name is None: + # edge case for when the pipeline was saved with safety_checker=None + init_kwargs[name] = None + continue + + is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + sub_model_should_be_defined = True + + # if the model is in a pipeline module, then we load it from the pipeline + if name in passed_class_obj: + # 1. check that passed_class_obj has correct parent class + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + elif passed_class_obj[name] is None: + logger.warning( + f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" + f" that this might lead to problems when using {pipeline_class} and is not recommended." + ) + sub_model_should_be_defined = False + else: + logger.warning( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + # set passed class object + loaded_sub_model = passed_class_obj[name] + elif is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = import_flax_or_no_model(pipeline_module, class_name) + + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. + library = importlib.import_module(library_name) + class_obj = import_flax_or_no_model(library, class_name) + + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + if loaded_sub_model is None and sub_model_should_be_defined: + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + load_method = getattr(class_obj, load_method_name) + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loadable_folder = os.path.join(cached_folder, name) + else: + loaded_sub_model = cached_folder + + if issubclass(class_obj, FlaxModelMixin): + loaded_sub_model, loaded_params = load_method( + loadable_folder, + from_pt=from_pt, + use_memory_efficient_attention=use_memory_efficient_attention, + split_head_dim=split_head_dim, + dtype=dtype, + ) + params[name] = loaded_params + elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): + if from_pt: + # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here + loaded_sub_model = load_method(loadable_folder, from_pt=from_pt) + loaded_params = loaded_sub_model.params + del loaded_sub_model._params + else: + loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) + params[name] = loaded_params + elif issubclass(class_obj, FlaxSchedulerMixin): + loaded_sub_model, scheduler_state = load_method(loadable_folder) + params[name] = scheduler_state + else: + loaded_sub_model = load_method(loadable_folder) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 4. Potentially add passed objects if expected + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + passed_modules = list(passed_class_obj.keys()) + + if len(missing_modules) > 0 and missing_modules <= set(passed_modules): + for module in missing_modules: + init_kwargs[module] = passed_class_obj.get(module, None) + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." + ) + + model = pipeline_class(**init_kwargs, dtype=dtype) + return model, params + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + return expected_modules, optional_parameters + + @property + def components(self) -> Dict[str, Any]: + r""" + + The `self.components` property can be useful to run different pipelines with the same weights and + configurations to not have to re-allocate memory. + + Examples: + + ```py + >>> from diffusers import ( + ... FlaxStableDiffusionPipeline, + ... FlaxStableDiffusionImg2ImgPipeline, + ... ) + + >>> text2img = FlaxStableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16 + ... ) + >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components) + ``` + + Returns: + A dictionary containing all the modules needed to initialize the pipeline. + """ + expected_modules, optional_parameters = self._get_signature_keys(self) + components = { + k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters + } + + if set(components.keys()) != expected_modules: + raise ValueError( + f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" + f" {expected_modules} to be defined, but {components} are defined." + ) + + return components + + @staticmethod + def numpy_to_pil(images): + """ + Convert a NumPy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + # TODO: make it compatible with jax.lax + def progress_bar(self, iterable): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + return tqdm(iterable, **self._progress_bar_config) + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs diff --git a/diffusers/src/diffusers/pipelines/pipeline_loading_utils.py b/diffusers/src/diffusers/pipelines/pipeline_loading_utils.py new file mode 100755 index 0000000..318599f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pipeline_loading_utils.py @@ -0,0 +1,838 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib +import os +import re +import warnings +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch +from huggingface_hub import ModelCard, model_info +from huggingface_hub.utils import validate_hf_hub_args +from packaging import version + +from .. import __version__ +from ..utils import ( + FLAX_WEIGHTS_NAME, + ONNX_EXTERNAL_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + deprecate, + get_class_from_dynamic_module, + is_accelerate_available, + is_peft_available, + is_transformers_available, + logging, +) +from ..utils.torch_utils import is_compiled_module + + +if is_transformers_available(): + import transformers + from transformers import PreTrainedModel + from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME + from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME + from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME + +if is_accelerate_available(): + import accelerate + from accelerate import dispatch_model + from accelerate.hooks import remove_hook_from_module + from accelerate.utils import compute_module_sizes, get_max_memory + + +INDEX_FILE = "diffusion_pytorch_model.bin" +CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" +DUMMY_MODULES_FOLDER = "diffusers.utils" +TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" +CONNECTED_PIPES_KEYS = ["prior"] + +logger = logging.get_logger(__name__) + +LOADABLE_CLASSES = { + "diffusers": { + "ModelMixin": ["save_pretrained", "from_pretrained"], + "SchedulerMixin": ["save_pretrained", "from_pretrained"], + "DiffusionPipeline": ["save_pretrained", "from_pretrained"], + "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "PreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + "ProcessorMixin": ["save_pretrained", "from_pretrained"], + "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], + }, + "onnxruntime.training": { + "ORTModule": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool: + """ + Checking for safetensors compatibility: + - The model is safetensors compatible only if there is a safetensors file for each model component present in + filenames. + + Converting default pytorch serialized filenames to safetensors serialized filenames: + - For models from the diffusers library, just replace the ".bin" extension with ".safetensors" + - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin" + extension is replaced with ".safetensors" + """ + passed_components = passed_components or [] + if folder_names is not None: + filenames = {f for f in filenames if os.path.split(f)[0] in folder_names} + + # extract all components of the pipeline and their associated files + components = {} + for filename in filenames: + if not len(filename.split("/")) == 2: + continue + + component, component_filename = filename.split("/") + if component in passed_components: + continue + + components.setdefault(component, []) + components[component].append(component_filename) + + # iterate over all files of a component + # check if safetensor files exist for that component + # if variant is provided check if the variant of the safetensors exists + for component, component_filenames in components.items(): + matches = [] + for component_filename in component_filenames: + filename, extension = os.path.splitext(component_filename) + + match_exists = extension == ".safetensors" + matches.append(match_exists) + + if not any(matches): + return False + + return True + + +def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]: + weight_names = [ + WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + FLAX_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + ONNX_EXTERNAL_WEIGHTS_NAME, + ] + + if is_transformers_available(): + weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME] + + # model_pytorch, diffusion_model_pytorch, ... + weight_prefixes = [w.split(".")[0] for w in weight_names] + # .bin, .safetensors, ... + weight_suffixs = [w.split(".")[-1] for w in weight_names] + # -00001-of-00002 + transformers_index_format = r"\d{5}-of-\d{5}" + + if variant is not None: + # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors` + variant_file_re = re.compile( + rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$" + ) + # `text_encoder/pytorch_model.bin.index.fp16.json` + variant_index_re = re.compile( + rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" + ) + + # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` + non_variant_file_re = re.compile( + rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$" + ) + # `text_encoder/pytorch_model.bin.index.json` + non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json") + + if variant is not None: + variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None} + variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None} + variant_filenames = variant_weights | variant_indexes + else: + variant_filenames = set() + + non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None} + non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None} + non_variant_filenames = non_variant_weights | non_variant_indexes + + # all variant filenames will be used by default + usable_filenames = set(variant_filenames) + + def convert_to_variant(filename): + if "index" in filename: + variant_filename = filename.replace("index", f"index.{variant}") + elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None: + variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}" + else: + variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" + return variant_filename + + for f in non_variant_filenames: + variant_filename = convert_to_variant(f) + if variant_filename not in usable_filenames: + usable_filenames.add(f) + + return usable_filenames, variant_filenames + + +@validate_hf_hub_args +def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames): + info = model_info( + pretrained_model_name_or_path, + token=token, + revision=None, + ) + filenames = {sibling.rfilename for sibling in info.siblings} + comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision) + comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames] + + if set(model_filenames).issubset(set(comp_model_filenames)): + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.", + FutureWarning, + ) + else: + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.", + FutureWarning, + ) + + +def _unwrap_model(model): + """Unwraps a model.""" + if is_compiled_module(model): + model = model._orig_mod + + if is_peft_available(): + from peft import PeftModel + + if isinstance(model, PeftModel): + model = model.base_model.model + + return model + + +def maybe_raise_or_warn( + library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module +): + """Simple helper method to raise or warn in case incorrect module has been passed""" + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + # Dynamo wraps the original model in a private class. + # I didn't find a public API to get the original class. + sub_model = passed_class_obj[name] + unwrapped_sub_model = _unwrap_model(sub_model) + model_cls = unwrapped_sub_model.__class__ + + if not issubclass(model_cls, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}" + ) + else: + logger.warning( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + +def get_class_obj_and_candidates( + library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None +): + """Simple helper method to retrieve class object of module as well as potential parent class objects""" + component_folder = os.path.join(cache_dir, component_name) + + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + + class_obj = getattr(pipeline_module, class_name) + class_candidates = {c: class_obj for c in importable_classes.keys()} + elif os.path.isfile(os.path.join(component_folder, library_name + ".py")): + # load custom component + class_obj = get_class_from_dynamic_module( + component_folder, module_file=library_name + ".py", class_name=class_name + ) + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. + library = importlib.import_module(library_name) + + class_obj = getattr(library, class_name) + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + return class_obj, class_candidates + + +def _get_custom_pipeline_class( + custom_pipeline, + repo_id=None, + hub_revision=None, + class_name=None, + cache_dir=None, + revision=None, +): + if custom_pipeline.endswith(".py"): + path = Path(custom_pipeline) + # decompose into folder & file + file_name = path.name + custom_pipeline = path.parent.absolute() + elif repo_id is not None: + file_name = f"{custom_pipeline}.py" + custom_pipeline = repo_id + else: + file_name = CUSTOM_PIPELINE_FILE_NAME + + if repo_id is not None and hub_revision is not None: + # if we load the pipeline code from the Hub + # make sure to overwrite the `revision` + revision = hub_revision + + return get_class_from_dynamic_module( + custom_pipeline, + module_file=file_name, + class_name=class_name, + cache_dir=cache_dir, + revision=revision, + ) + + +def _get_pipeline_class( + class_obj, + config=None, + load_connected_pipeline=False, + custom_pipeline=None, + repo_id=None, + hub_revision=None, + class_name=None, + cache_dir=None, + revision=None, +): + if custom_pipeline is not None: + return _get_custom_pipeline_class( + custom_pipeline, + repo_id=repo_id, + hub_revision=hub_revision, + class_name=class_name, + cache_dir=cache_dir, + revision=revision, + ) + + if class_obj.__name__ != "DiffusionPipeline": + return class_obj + + diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) + class_name = class_name or config["_class_name"] + if not class_name: + raise ValueError( + "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`." + ) + + class_name = class_name[4:] if class_name.startswith("Flax") else class_name + + pipeline_cls = getattr(diffusers_module, class_name) + + if load_connected_pipeline: + from .auto_pipeline import _get_connected_pipeline + + connected_pipeline_cls = _get_connected_pipeline(pipeline_cls) + if connected_pipeline_cls is not None: + logger.info( + f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`" + ) + else: + logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.") + + pipeline_cls = connected_pipeline_cls or pipeline_cls + + return pipeline_cls + + +def _load_empty_model( + library_name: str, + class_name: str, + importable_classes: List[Any], + pipelines: Any, + is_pipeline_module: bool, + name: str, + torch_dtype: Union[str, torch.dtype], + cached_folder: Union[str, os.PathLike], + **kwargs, +): + # retrieve class objects. + class_obj, _ = get_class_obj_and_candidates( + library_name, + class_name, + importable_classes, + pipelines, + is_pipeline_module, + component_name=name, + cache_dir=cached_folder, + ) + + if is_transformers_available(): + transformers_version = version.parse(version.parse(transformers.__version__).base_version) + else: + transformers_version = "N/A" + + # Determine library. + is_transformers_model = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedModel) + and transformers_version >= version.parse("4.20.0") + ) + diffusers_module = importlib.import_module(__name__.split(".")[0]) + is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) + + model = None + config_path = cached_folder + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + if is_diffusers_model: + # Load config and then the model on meta. + config, unused_kwargs, commit_hash = class_obj.load_config( + os.path.join(config_path, name), + cache_dir=cached_folder, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=kwargs.pop("force_download", False), + proxies=kwargs.pop("proxies", None), + local_files_only=kwargs.pop("local_files_only", False), + token=kwargs.pop("token", None), + revision=kwargs.pop("revision", None), + subfolder=kwargs.pop("subfolder", None), + user_agent=user_agent, + ) + with accelerate.init_empty_weights(): + model = class_obj.from_config(config, **unused_kwargs) + elif is_transformers_model: + config_class = getattr(class_obj, "config_class", None) + if config_class is None: + raise ValueError("`config_class` cannot be None. Please double-check the model.") + + config = config_class.from_pretrained( + cached_folder, + subfolder=name, + force_download=kwargs.pop("force_download", False), + proxies=kwargs.pop("proxies", None), + local_files_only=kwargs.pop("local_files_only", False), + token=kwargs.pop("token", None), + revision=kwargs.pop("revision", None), + user_agent=user_agent, + ) + with accelerate.init_empty_weights(): + model = class_obj(config) + + if model is not None: + model = model.to(dtype=torch_dtype) + return model + + +def _assign_components_to_devices( + module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced" +): + device_ids = list(device_memory.keys()) + device_cycle = device_ids + device_ids[::-1] + device_memory = device_memory.copy() + + device_id_component_mapping = {} + current_device_index = 0 + for component in module_sizes: + device_id = device_cycle[current_device_index % len(device_cycle)] + component_memory = module_sizes[component] + curr_device_memory = device_memory[device_id] + + # If the GPU doesn't fit the current component offload to the CPU. + if component_memory > curr_device_memory: + device_id_component_mapping["cpu"] = [component] + else: + if device_id not in device_id_component_mapping: + device_id_component_mapping[device_id] = [component] + else: + device_id_component_mapping[device_id].append(component) + + # Update the device memory. + device_memory[device_id] -= component_memory + current_device_index += 1 + + return device_id_component_mapping + + +def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs): + # To avoid circular import problem. + from diffusers import pipelines + + torch_dtype = kwargs.get("torch_dtype", torch.float32) + + # Load each module in the pipeline on a meta device so that we can derive the device map. + init_empty_modules = {} + for name, (library_name, class_name) in init_dict.items(): + if class_name.startswith("Flax"): + raise ValueError("Flax pipelines are not supported with `device_map`.") + + # Define all importable classes + is_pipeline_module = hasattr(pipelines, library_name) + importable_classes = ALL_IMPORTABLE_CLASSES + loaded_sub_model = None + + # Use passed sub model or load class_name from library_name + if name in passed_class_obj: + # if the model is in a pipeline module, then we load it from the pipeline + # check that passed_class_obj has correct parent class + maybe_raise_or_warn( + library_name, + library, + class_name, + importable_classes, + passed_class_obj, + name, + is_pipeline_module, + ) + with accelerate.init_empty_weights(): + loaded_sub_model = passed_class_obj[name] + + else: + loaded_sub_model = _load_empty_model( + library_name=library_name, + class_name=class_name, + importable_classes=importable_classes, + pipelines=pipelines, + is_pipeline_module=is_pipeline_module, + pipeline_class=pipeline_class, + name=name, + torch_dtype=torch_dtype, + cached_folder=kwargs.get("cached_folder", None), + force_download=kwargs.get("force_download", None), + proxies=kwargs.get("proxies", None), + local_files_only=kwargs.get("local_files_only", None), + token=kwargs.get("token", None), + revision=kwargs.get("revision", None), + ) + + if loaded_sub_model is not None: + init_empty_modules[name] = loaded_sub_model + + # determine device map + # Obtain a sorted dictionary for mapping the model-level components + # to their sizes. + module_sizes = { + module_name: compute_module_sizes(module, dtype=torch_dtype)[""] + for module_name, module in init_empty_modules.items() + if isinstance(module, torch.nn.Module) + } + module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True)) + + # Obtain maximum memory available per device (GPUs only). + max_memory = get_max_memory(max_memory) + max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True)) + max_memory = {k: v for k, v in max_memory.items() if k != "cpu"} + + # Obtain a dictionary mapping the model-level components to the available + # devices based on the maximum memory and the model sizes. + final_device_map = None + if len(max_memory) > 0: + device_id_component_mapping = _assign_components_to_devices( + module_sizes, max_memory, device_mapping_strategy=device_map + ) + + # Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}` + final_device_map = {} + for device_id, components in device_id_component_mapping.items(): + for component in components: + final_device_map[component] = device_id + + return final_device_map + + +def load_sub_model( + library_name: str, + class_name: str, + importable_classes: List[Any], + pipelines: Any, + is_pipeline_module: bool, + pipeline_class: Any, + torch_dtype: torch.dtype, + provider: Any, + sess_options: Any, + device_map: Optional[Union[Dict[str, torch.device], str]], + max_memory: Optional[Dict[Union[int, str], Union[int, str]]], + offload_folder: Optional[Union[str, os.PathLike]], + offload_state_dict: bool, + model_variants: Dict[str, str], + name: str, + from_flax: bool, + variant: str, + low_cpu_mem_usage: bool, + cached_folder: Union[str, os.PathLike], +): + """Helper method to load the module `name` from `library_name` and `class_name`""" + + # retrieve class candidates + + class_obj, class_candidates = get_class_obj_and_candidates( + library_name, + class_name, + importable_classes, + pipelines, + is_pipeline_module, + component_name=name, + cache_dir=cached_folder, + ) + + load_method_name = None + # retrieve load method name + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + # if load method name is None, then we have a dummy module -> raise Error + if load_method_name is None: + none_module = class_obj.__module__ + is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( + TRANSFORMERS_DUMMY_MODULES_FOLDER + ) + if is_dummy_path and "dummy" in none_module: + # call class_obj for nice error message of missing requirements + class_obj() + + raise ValueError( + f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" + f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." + ) + + load_method = getattr(class_obj, load_method_name) + + # add kwargs to loading method + diffusers_module = importlib.import_module(__name__.split(".")[0]) + loading_kwargs = {} + if issubclass(class_obj, torch.nn.Module): + loading_kwargs["torch_dtype"] = torch_dtype + if issubclass(class_obj, diffusers_module.OnnxRuntimeModel): + loading_kwargs["provider"] = provider + loading_kwargs["sess_options"] = sess_options + + is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) + + if is_transformers_available(): + transformers_version = version.parse(version.parse(transformers.__version__).base_version) + else: + transformers_version = "N/A" + + is_transformers_model = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedModel) + and transformers_version >= version.parse("4.20.0") + ) + + # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. + # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. + # This makes sure that the weights won't be initialized which significantly speeds up loading. + if is_diffusers_model or is_transformers_model: + loading_kwargs["device_map"] = device_map + loading_kwargs["max_memory"] = max_memory + loading_kwargs["offload_folder"] = offload_folder + loading_kwargs["offload_state_dict"] = offload_state_dict + loading_kwargs["variant"] = model_variants.pop(name, None) + + if from_flax: + loading_kwargs["from_flax"] = True + + # the following can be deleted once the minimum required `transformers` version + # is higher than 4.27 + if ( + is_transformers_model + and loading_kwargs["variant"] is not None + and transformers_version < version.parse("4.27.0") + ): + raise ImportError( + f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0" + ) + elif is_transformers_model and loading_kwargs["variant"] is None: + loading_kwargs.pop("variant") + + # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage` + if not (from_flax and is_transformers_model): + loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + else: + loading_kwargs["low_cpu_mem_usage"] = False + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + else: + # else load from the root directory + loaded_sub_model = load_method(cached_folder, **loading_kwargs) + + if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict): + # remove hooks + remove_hook_from_module(loaded_sub_model, recurse=True) + needs_offloading_to_cpu = device_map[""] == "cpu" + + if needs_offloading_to_cpu: + dispatch_model( + loaded_sub_model, + state_dict=loaded_sub_model.state_dict(), + device_map=device_map, + force_hooks=True, + main_device=0, + ) + else: + dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True) + + return loaded_sub_model + + +def _fetch_class_library_tuple(module): + # import it here to avoid circular import + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") + + # register the config from the original module, not the dynamo compiled one + not_compiled_module = _unwrap_model(module) + library = not_compiled_module.__module__.split(".")[0] + + # check if the module is a pipeline module + module_path_items = not_compiled_module.__module__.split(".") + pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None + + path = not_compiled_module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if is_pipeline_module: + library = pipeline_dir + elif library not in LOADABLE_CLASSES: + library = not_compiled_module.__module__ + + # retrieve class_name + class_name = not_compiled_module.__class__.__name__ + + return (library, class_name) + + +def _identify_model_variants(folder: str, variant: str, config: dict) -> dict: + model_variants = {} + if variant is not None: + for sub_folder in os.listdir(folder): + folder_path = os.path.join(folder, sub_folder) + is_folder = os.path.isdir(folder_path) and sub_folder in config + variant_exists = is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)) + if variant_exists: + model_variants[sub_folder] = variant + return model_variants + + +def _resolve_custom_pipeline_and_cls(folder, config, custom_pipeline): + custom_class_name = None + if os.path.isfile(os.path.join(folder, f"{custom_pipeline}.py")): + custom_pipeline = os.path.join(folder, f"{custom_pipeline}.py") + elif isinstance(config["_class_name"], (list, tuple)) and os.path.isfile( + os.path.join(folder, f"{config['_class_name'][0]}.py") + ): + custom_pipeline = os.path.join(folder, f"{config['_class_name'][0]}.py") + custom_class_name = config["_class_name"][1] + + return custom_pipeline, custom_class_name + + +def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or_path: str, config: dict): + if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( + version.parse(config["_diffusers_version"]).base_version + ) <= version.parse("0.5.1"): + from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy + + pipeline_class = StableDiffusionInpaintPipelineLegacy + + deprecation_message = ( + "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" + f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" + " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" + " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" + f" checkpoint {pretrained_model_name_or_path} to the format of" + " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" + " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." + ) + deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) + + +def _update_init_kwargs_with_connected_pipeline( + init_kwargs: dict, passed_pipe_kwargs: dict, passed_class_objs: dict, folder: str, **pipeline_loading_kwargs +) -> dict: + from .pipeline_utils import DiffusionPipeline + + modelcard = ModelCard.load(os.path.join(folder, "README.md")) + connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS} + + # We don't scheduler argument to match the existing logic: + # https://github.com/huggingface/diffusers/blob/867e0c919e1aa7ef8b03c8eb1460f4f875a683ae/src/diffusers/pipelines/pipeline_utils.py#L906C13-L925C14 + pipeline_loading_kwargs_cp = pipeline_loading_kwargs.copy() + if pipeline_loading_kwargs_cp is not None and len(pipeline_loading_kwargs_cp) >= 1: + for k in pipeline_loading_kwargs: + if "scheduler" in k: + _ = pipeline_loading_kwargs_cp.pop(k) + + def get_connected_passed_kwargs(prefix): + connected_passed_class_obj = { + k.replace(f"{prefix}_", ""): w for k, w in passed_class_objs.items() if k.split("_")[0] == prefix + } + connected_passed_pipe_kwargs = { + k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix + } + + connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs} + return connected_passed_kwargs + + connected_pipes = { + prefix: DiffusionPipeline.from_pretrained( + repo_id, **pipeline_loading_kwargs_cp, **get_connected_passed_kwargs(prefix) + ) + for prefix, repo_id in connected_pipes.items() + if repo_id is not None + } + + for prefix, connected_pipe in connected_pipes.items(): + # add connected pipes to `init_kwargs` with _, e.g. "prior_text_encoder" + init_kwargs.update( + {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()} + ) + + return init_kwargs diff --git a/diffusers/src/diffusers/pipelines/pipeline_utils.py b/diffusers/src/diffusers/pipelines/pipeline_utils.py new file mode 100755 index 0000000..ccd1c94 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pipeline_utils.py @@ -0,0 +1,1944 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import fnmatch +import importlib +import inspect +import os +import re +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin + +import numpy as np +import PIL.Image +import requests +import torch +from huggingface_hub import ( + ModelCard, + create_repo, + hf_hub_download, + model_info, + snapshot_download, +) +from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args +from packaging import version +from requests.exceptions import HTTPError +from tqdm.auto import tqdm + +from .. import __version__ +from ..configuration_utils import ConfigMixin +from ..models import AutoencoderKL +from ..models.attention_processor import FusedAttnProcessor2_0 +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin +from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from ..utils import ( + CONFIG_NAME, + DEPRECATED_REVISION_ARGS, + BaseOutput, + PushToHubMixin, + deprecate, + is_accelerate_available, + is_accelerate_version, + is_torch_npu_available, + is_torch_version, + logging, + numpy_to_pil, +) +from ..utils.hub_utils import load_or_create_model_card, populate_model_card +from ..utils.torch_utils import is_compiled_module + + +if is_torch_npu_available(): + import torch_npu # noqa: F401 + + +from .pipeline_loading_utils import ( + ALL_IMPORTABLE_CLASSES, + CONNECTED_PIPES_KEYS, + CUSTOM_PIPELINE_FILE_NAME, + LOADABLE_CLASSES, + _fetch_class_library_tuple, + _get_custom_pipeline_class, + _get_final_device_map, + _get_pipeline_class, + _identify_model_variants, + _maybe_raise_warning_for_inpainting, + _resolve_custom_pipeline_and_cls, + _unwrap_model, + _update_init_kwargs_with_connected_pipeline, + is_safetensors_compatible, + load_sub_model, + maybe_raise_or_warn, + variant_compatible_siblings, + warn_deprecated_model_variant, +) + + +if is_accelerate_available(): + import accelerate + + +LIBRARIES = [] +for library in LOADABLE_CLASSES: + LIBRARIES.append(library) + +SUPPORTED_DEVICE_MAP = ["balanced"] + +logger = logging.get_logger(__name__) + + +@dataclass +class ImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +@dataclass +class AudioPipelineOutput(BaseOutput): + """ + Output class for audio pipelines. + + Args: + audios (`np.ndarray`) + List of denoised audio samples of a NumPy array of shape `(batch_size, num_channels, sample_rate)`. + """ + + audios: np.ndarray + + +class DiffusionPipeline(ConfigMixin, PushToHubMixin): + r""" + Base class for all pipelines. + + [`DiffusionPipeline`] stores all components (models, schedulers, and processors) for diffusion pipelines and + provides methods for loading, downloading and saving models. It also includes methods to: + + - move all PyTorch modules to the device of your choice + - enable/disable the progress bar for the denoising iteration + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + - **_optional_components** (`List[str]`) -- List of all optional components that don't have to be passed to the + pipeline to function (should be overridden by subclasses). + """ + + config_name = "model_index.json" + model_cpu_offload_seq = None + hf_device_map = None + _optional_components = [] + _exclude_from_cpu_offload = [] + _load_connected_pipes = False + _is_onnx = False + + def register_modules(self, **kwargs): + for name, module in kwargs.items(): + # retrieve library + if module is None or isinstance(module, (tuple, list)) and module[0] is None: + register_dict = {name: (None, None)} + else: + library, class_name = _fetch_class_library_tuple(module) + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def __setattr__(self, name: str, value: Any): + if name in self.__dict__ and hasattr(self.config, name): + # We need to overwrite the config if name exists in config + if isinstance(getattr(self.config, name), (tuple, list)): + if value is not None and self.config[name][0] is not None: + class_library_tuple = _fetch_class_library_tuple(value) + else: + class_library_tuple = (None, None) + + self.register_to_config(**{name: class_library_tuple}) + else: + self.register_to_config(**{name: value}) + + super().__setattr__(name, value) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = True, + variant: Optional[str] = None, + max_shard_size: Optional[Union[int, str]] = None, + push_to_hub: bool = False, + **kwargs, + ): + """ + Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its + class implements both a save and loading method. The pipeline is easily reloaded using the + [`~DiffusionPipeline.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save a pipeline to. Will be created if it doesn't exist. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + variant (`str`, *optional*): + If specified, weights are saved in the format `pytorch_model..bin`. + max_shard_size (`int` or `str`, defaults to `None`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`). + If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain + period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`. + This is to establish a common default size for this argument across different libraries in the Hugging + Face ecosystem (`transformers`, and `accelerate`, for example). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name", None) + model_index_dict.pop("_diffusers_version", None) + model_index_dict.pop("_module", None) + model_index_dict.pop("_name_or_path", None) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + expected_modules, optional_kwargs = self._get_signature_keys(self) + + def is_saveable_module(name, value): + if name not in expected_modules: + return False + if name in self._optional_components and value[0] is None: + return False + return True + + model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + model_cls = sub_model.__class__ + + # Dynamo wraps the original model in a private class. + # I didn't find a public API to get the original class. + if is_compiled_module(sub_model): + sub_model = _unwrap_model(sub_model) + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + if library_name in sys.modules: + library = importlib.import_module(library_name) + else: + logger.info( + f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}" + ) + + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class, None) + if class_candidate is not None and issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + if save_method_name is None: + logger.warning( + f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved." + ) + # make sure that unsaveable components are not tried to be loaded afterward + self.register_to_config(**{pipeline_component_name: (None, None)}) + continue + + save_method = getattr(sub_model, save_method_name) + + # Call the save method with the argument safe_serialization only if it's supported + save_method_signature = inspect.signature(save_method) + save_method_accept_safe = "safe_serialization" in save_method_signature.parameters + save_method_accept_variant = "variant" in save_method_signature.parameters + save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters + + save_kwargs = {} + if save_method_accept_safe: + save_kwargs["safe_serialization"] = safe_serialization + if save_method_accept_variant: + save_kwargs["variant"] = variant + if save_method_accept_max_shard_size and max_shard_size is not None: + # max_shard_size is expected to not be None in ModelMixin + save_kwargs["max_shard_size"] = max_shard_size + + save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) + + # finally save the config + self.save_config(save_directory) + + if push_to_hub: + # Create a new empty model card and eventually tag it + model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True) + model_card = populate_model_card(model_card) + model_card.save(os.path.join(save_directory, "README.md")) + + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + + def to(self, *args, **kwargs): + r""" + Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the + arguments of `self.to(*args, **kwargs).` + + + + If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise, + the returned pipeline is a copy of self with the desired torch.dtype and torch.device. + + + + + Here are the ways to call `to`: + + - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified + [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) + - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the + specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + + Arguments: + dtype (`torch.dtype`, *optional*): + Returns a pipeline with the specified + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + device (`torch.Device`, *optional*): + Returns a pipeline with the specified + [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) + silence_dtype_warnings (`str`, *optional*, defaults to `False`): + Whether to omit warnings if the target `dtype` is not compatible with the target `device`. + + Returns: + [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. + """ + dtype = kwargs.pop("dtype", None) + device = kwargs.pop("device", None) + silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) + + dtype_arg = None + device_arg = None + if len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype_arg = args[0] + else: + device_arg = torch.device(args[0]) if args[0] is not None else None + elif len(args) == 2: + if isinstance(args[0], torch.dtype): + raise ValueError( + "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`." + ) + device_arg = torch.device(args[0]) if args[0] is not None else None + dtype_arg = args[1] + elif len(args) > 2: + raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`") + + if dtype is not None and dtype_arg is not None: + raise ValueError( + "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + dtype = dtype or dtype_arg + + if device is not None and device_arg is not None: + raise ValueError( + "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + device = device or device_arg + + # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. + def module_is_sequentially_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): + return False + + return hasattr(module, "_hf_hook") and ( + isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook) + or hasattr(module._hf_hook, "hooks") + and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook) + ) + + def module_is_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): + return False + + return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) + + # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer + pipeline_is_sequentially_offloaded = any( + module_is_sequentially_offloaded(module) for _, module in self.components.items() + ) + if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": + raise ValueError( + "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." + ) + + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + if is_pipeline_device_mapped: + raise ValueError( + "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`." + ) + + # Display a warning in this case (the operation succeeds but the benefits are lost) + pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) + if pipeline_is_offloaded and device and torch.device(device).type == "cuda": + logger.warning( + f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." + ) + + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded + for module in modules: + is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit + + if is_loaded_in_8bit and dtype is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." + ) + + if is_loaded_in_8bit and device is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." + ) + else: + module.to(device, dtype) + + if ( + module.dtype == torch.float16 + and str(device) in ["cpu"] + and not silence_dtype_warnings + and not is_offloaded + ): + logger.warning( + "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" + " is not recommended to move them to `cpu` as running them will fail. Please make" + " sure to use an accelerator to run the pipeline in inference, due to the lack of" + " support for`float16` operations on this device in PyTorch. Please, remove the" + " `torch_dtype=torch.float16` argument, or use another device for inference." + ) + return self + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + + @property + def dtype(self) -> torch.dtype: + r""" + Returns: + `torch.dtype`: The torch dtype on which the pipeline is located. + """ + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.dtype + + return torch.float32 + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + If you get the error message below, you need to finetune the weights for your downstream task: + + ``` + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights + saved using + [`~DiffusionPipeline.save_pretrained`]. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the + dtype is automatically derived from the model's weights. + custom_pipeline (`str`, *optional*): + + + + 🧪 This is an experimental feature and may change in the future. + + + + Can be either: + + - A string, the *repo id* (for example `hf-internal-testing/diffusers-dummy-pipeline`) of a custom + pipeline hosted on the Hub. The repository must contain a file called pipeline.py that defines + the custom pipeline. + - A string, the *file name* of a community pipeline hosted on GitHub under + [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file + names must match the file name and not the pipeline script (`clip_guided_stable_diffusion` + instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the + current main branch of GitHub. + - A path to a directory (`./my_pipeline_directory/`) containing a custom pipeline. The directory + must contain a file called `pipeline.py` that defines the custom pipeline. + + For more information on how to load and create custom pipelines, please have a look at [Loading and + Adding Custom + Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers + version. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn’t need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if device_map contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + use_onnx (`bool`, *optional*, defaults to `None`): + If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights + will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is + `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending + with `.onnx` and `.pb`. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. + variant (`str`, *optional*): + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + + + + To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with + `huggingface-cli login`. + + + + Examples: + + ```py + >>> from diffusers import DiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + + >>> # Use a different scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.scheduler = scheduler + ``` + """ + # Copy the kwargs to re-use during loading connected pipeline. + kwargs_copied = kwargs.copy() + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + from_flax = kwargs.pop("from_flax", False) + torch_dtype = kwargs.pop("torch_dtype", None) + custom_pipeline = kwargs.pop("custom_pipeline", None) + custom_revision = kwargs.pop("custom_revision", None) + provider = kwargs.pop("provider", None) + sess_options = kwargs.pop("sess_options", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + use_onnx = kwargs.pop("use_onnx", None) + load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`." + ) + + if device_map is not None and not isinstance(device_map, str): + raise ValueError("`device_map` must be a string.") + + if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP: + raise NotImplementedError( + f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}" + ) + + if device_map is not None and device_map in SUPPORTED_DEVICE_MAP: + if is_accelerate_version("<", "0.28.0"): + raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.") + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + if pretrained_model_name_or_path.count("/") > 1: + raise ValueError( + f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"' + " is neither a valid local path nor a valid repo id. Please check the parameter." + ) + cached_folder = cls.download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + from_flax=from_flax, + use_safetensors=use_safetensors, + use_onnx=use_onnx, + custom_pipeline=custom_pipeline, + custom_revision=custom_revision, + variant=variant, + load_connected_pipeline=load_connected_pipeline, + **kwargs, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.load_config(cached_folder) + + # pop out "_ignore_files" as it is only needed for download + config_dict.pop("_ignore_files", None) + + # 2. Define which model components should load variants + # We retrieve the information by matching whether variant model checkpoints exist in the subfolders. + # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors` + # with variant being `"fp16"`. + model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict) + + # 3. Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls( + folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline + ) + pipeline_class = _get_pipeline_class( + cls, + config=config_dict, + load_connected_pipeline=load_connected_pipeline, + custom_pipeline=custom_pipeline, + class_name=custom_class_name, + cache_dir=cache_dir, + revision=custom_revision, + ) + + if device_map is not None and pipeline_class._load_connected_pipes: + raise NotImplementedError("`device_map` is not yet supported for connected pipelines.") + + # DEPRECATED: To be removed in 1.0.0 + # we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded + # when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1. + _maybe_raise_warning_for_inpainting( + pipeline_class=pipeline_class, + pretrained_model_name_or_path=pretrained_model_name_or_path, + config=config_dict, + ) + + # 4. Define expected modules given pipeline signature + # and define non-None initialized modules (=`init_kwargs`) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + # define init kwargs and make sure that optional component modules are filtered out + init_kwargs = { + k: init_dict.pop(k) + for k in optional_kwargs + if k in init_dict and k not in pipeline_class._optional_components + } + init_kwargs = {**init_kwargs, **passed_pipe_kwargs} + + # remove `null` components + def load_module(name, value): + if value[0] is None: + return False + if name in passed_class_obj and passed_class_obj[name] is None: + return False + return True + + init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + + # Special case: safety_checker must be loaded separately when using `from_flax` + if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: + raise NotImplementedError( + "The safety checker cannot be automatically loaded when loading weights `from_flax`." + " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker" + " separately if you need it." + ) + + # 5. Throw nice warnings / errors for fast accelerate loading + if len(unused_kwargs) > 0: + logger.warning( + f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." + ) + + # import it here to avoid circular import + from diffusers import pipelines + + # 6. device map delegation + final_device_map = None + if device_map is not None: + final_device_map = _get_final_device_map( + device_map=device_map, + pipeline_class=pipeline_class, + passed_class_obj=passed_class_obj, + init_dict=init_dict, + library=library, + max_memory=max_memory, + torch_dtype=torch_dtype, + cached_folder=cached_folder, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + ) + + # 7. Load each module in the pipeline + current_device_map = None + for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."): + # 7.1 device_map shenanigans + if final_device_map is not None and len(final_device_map) > 0: + component_device = final_device_map.get(name, None) + if component_device is not None: + current_device_map = {"": component_device} + else: + current_device_map = None + + # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names + class_name = class_name[4:] if class_name.startswith("Flax") else class_name + + # 7.3 Define all importable classes + is_pipeline_module = hasattr(pipelines, library_name) + importable_classes = ALL_IMPORTABLE_CLASSES + loaded_sub_model = None + + # 7.4 Use passed sub model or load class_name from library_name + if name in passed_class_obj: + # if the model is in a pipeline module, then we load it from the pipeline + # check that passed_class_obj has correct parent class + maybe_raise_or_warn( + library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module + ) + + loaded_sub_model = passed_class_obj[name] + else: + # load sub model + loaded_sub_model = load_sub_model( + library_name=library_name, + class_name=class_name, + importable_classes=importable_classes, + pipelines=pipelines, + is_pipeline_module=is_pipeline_module, + pipeline_class=pipeline_class, + torch_dtype=torch_dtype, + provider=provider, + sess_options=sess_options, + device_map=current_device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + model_variants=model_variants, + name=name, + from_flax=from_flax, + variant=variant, + low_cpu_mem_usage=low_cpu_mem_usage, + cached_folder=cached_folder, + ) + logger.info( + f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." + ) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 8. Handle connected pipelines. + if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")): + init_kwargs = _update_init_kwargs_with_connected_pipeline( + init_kwargs=init_kwargs, + passed_pipe_kwargs=passed_pipe_kwargs, + passed_class_objs=passed_class_obj, + folder=cached_folder, + **kwargs_copied, + ) + + # 9. Potentially add passed objects if expected + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + passed_modules = list(passed_class_obj.keys()) + optional_modules = pipeline_class._optional_components + if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): + for module in missing_modules: + init_kwargs[module] = passed_class_obj.get(module, None) + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." + ) + + # 10. Instantiate the pipeline + model = pipeline_class(**init_kwargs) + + # 11. Save where the model was instantiated from + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + if device_map is not None: + setattr(model, "hf_device_map", final_device_map) + return model + + @property + def name_or_path(self) -> str: + return getattr(self.config, "_name_or_path", None) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from + Accelerate's module hooks. + """ + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: + continue + + if not hasattr(model, "_hf_hook"): + return self.device + for module in model.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def remove_all_hooks(self): + r""" + Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`. + """ + for _, model in self.components.items(): + if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"): + accelerate.hooks.remove_hook_from_module(model, recurse=True) + self._all_hooks = [] + + def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + + Arguments: + gpu_id (`int`, *optional*): + The ID of the accelerator that shall be used in inference. If not specified, it will default to 0. + device (`torch.Device` or `str`, *optional*, defaults to "cuda"): + The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will + default to "cuda". + """ + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + if is_pipeline_device_mapped: + raise ValueError( + "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`." + ) + + if self.model_cpu_offload_seq is None: + raise ValueError( + "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set." + ) + + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + self.remove_all_hooks() + + torch_device = torch.device(device) + device_index = torch_device.index + + if gpu_id is not None and device_index is not None: + raise ValueError( + f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}" + f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}" + ) + + # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0 + self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0) + + device_type = torch_device.type + device = torch.device(f"{device_type}:{self._offload_gpu_id}") + self._offload_device = device + + self.to("cpu", silence_dtype_warnings=True) + device_mod = getattr(torch, device.type, None) + if hasattr(device_mod, "empty_cache") and device_mod.is_available(): + device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} + + self._all_hooks = [] + hook = None + for model_str in self.model_cpu_offload_seq.split("->"): + model = all_model_components.pop(model_str, None) + if not isinstance(model, torch.nn.Module): + continue + + _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook) + self._all_hooks.append(hook) + + # CPU offload models that are not in the seq chain unless they are explicitly excluded + # these models will stay on CPU until maybe_free_model_hooks is called + # some models cannot be in the seq chain because they are iteratively called, such as controlnet + for name, model in all_model_components.items(): + if not isinstance(model, torch.nn.Module): + continue + + if name in self._exclude_from_cpu_offload: + model.to(device) + else: + _, hook = cpu_offload_with_hook(model, device) + self._all_hooks.append(hook) + + def maybe_free_model_hooks(self): + r""" + Function that offloads all components, removes all model hooks that were added when using + `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function + is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it + functions correctly when applying enable_model_cpu_offload. + """ + if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: + # `enable_model_cpu_offload` has not be called, so silently do nothing + return + + # make sure the model is in the same state as before calling it + self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda")) + + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + r""" + Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state + dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU + and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` + method called. Offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + + Arguments: + gpu_id (`int`, *optional*): + The ID of the accelerator that shall be used in inference. If not specified, it will default to 0. + device (`torch.Device` or `str`, *optional*, defaults to "cuda"): + The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will + default to "cuda". + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + self.remove_all_hooks() + + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + if is_pipeline_device_mapped: + raise ValueError( + "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`." + ) + + torch_device = torch.device(device) + device_index = torch_device.index + + if gpu_id is not None and device_index is not None: + raise ValueError( + f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}" + f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}" + ) + + # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0 + self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0) + + device_type = torch_device.type + device = torch.device(f"{device_type}:{self._offload_gpu_id}") + self._offload_device = device + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + device_mod = getattr(torch, self.device.type, None) + if hasattr(device_mod, "empty_cache") and device_mod.is_available(): + device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module): + continue + + if name in self._exclude_from_cpu_offload: + model.to(device) + else: + # make sure to offload buffers if not all high level weights + # are of type nn.Module + offload_buffers = len(model._parameters) > 0 + cpu_offload(model, device, offload_buffers=offload_buffers) + + def reset_device_map(self): + r""" + Resets the device maps (if any) to None. + """ + if self.hf_device_map is None: + return + else: + self.remove_all_hooks() + for name, component in self.components.items(): + if isinstance(component, torch.nn.Module): + component.to("cpu") + self.hf_device_map = None + + @classmethod + @validate_hf_hub_args + def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: + r""" + Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights. + + Parameters: + pretrained_model_name (`str` or `os.PathLike`, *optional*): + A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + custom_pipeline (`str`, *optional*): + Can be either: + + - A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained + pipeline hosted on the Hub. The repository must contain a file called `pipeline.py` that defines + the custom pipeline. + + - A string, the *file name* of a community pipeline hosted on GitHub under + [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file + names must match the file name and not the pipeline script (`clip_guided_stable_diffusion` + instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the + current `main` branch of GitHub. + + - A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory + must contain a file called `pipeline.py` that defines the custom pipeline. + + + + 🧪 This is an experimental feature and may change in the future. + + + + For more information on how to load and create custom pipelines, take a look at [How to contribute a + community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline). + + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a + custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + variant (`str`, *optional*): + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + use_onnx (`bool`, *optional*, defaults to `False`): + If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights + will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is + `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending + with `.onnx` and `.pb`. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This + option should only be set to `True` for repositories you trust and in which you have read the code, as + it will execute code present on the Hub on your local machine. + + Returns: + `os.PathLike`: + A path to the downloaded pipeline. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. + + + + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + from_flax = kwargs.pop("from_flax", False) + custom_pipeline = kwargs.pop("custom_pipeline", None) + custom_revision = kwargs.pop("custom_revision", None) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + use_onnx = kwargs.pop("use_onnx", None) + load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) + trust_remote_code = kwargs.pop("trust_remote_code", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + allow_patterns = None + ignore_patterns = None + + model_info_call_error: Optional[Exception] = None + if not local_files_only: + try: + info = model_info(pretrained_model_name, token=token, revision=revision) + except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e: + logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.") + local_files_only = True + model_info_call_error = e # save error to reraise it if model is not cached locally + + if not local_files_only: + config_file = hf_hub_download( + pretrained_model_name, + cls.config_name, + cache_dir=cache_dir, + revision=revision, + proxies=proxies, + force_download=force_download, + token=token, + ) + + config_dict = cls._dict_from_json_file(config_file) + ignore_filenames = config_dict.pop("_ignore_files", []) + + # retrieve all folder_names that contain relevant files + folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] + + filenames = {sibling.rfilename for sibling in info.siblings} + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") + + # optionally create a custom component <> custom file mapping + custom_components = {} + for component in folder_names: + module_candidate = config_dict[component][0] + + if module_candidate is None or not isinstance(module_candidate, str): + continue + + # We compute candidate file path on the Hub. Do not use `os.path.join`. + candidate_file = f"{component}/{module_candidate}.py" + + if candidate_file in filenames: + custom_components[component] = module_candidate + elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate): + raise ValueError( + f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." + ) + + if len(variant_filenames) == 0 and variant is not None: + deprecation_message = ( + f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." + f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" + "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant" + "modeling files is deprecated." + ) + deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False) + + # remove ignored filenames + model_filenames = set(model_filenames) - set(ignore_filenames) + variant_filenames = set(variant_filenames) - set(ignore_filenames) + + # if the whole pipeline is cached we don't have to ping the Hub + if revision in DEPRECATED_REVISION_ARGS and version.parse( + version.parse(__version__).base_version + ) >= version.parse("0.22.0"): + warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames) + + model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} + + custom_class_name = None + if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)): + custom_pipeline = config_dict["_class_name"][0] + custom_class_name = config_dict["_class_name"][1] + + # all filenames compatible with variant will be added + allow_patterns = list(model_filenames) + + # allow all patterns from non-model folders + # this enables downloading schedulers, tokenizers, ... + allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] + # add custom component files + allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()] + # add custom pipeline file + allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else [] + # also allow downloading config.json files with the model + allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] + + allow_patterns += [ + SCHEDULER_CONFIG_NAME, + CONFIG_NAME, + cls.config_name, + CUSTOM_PIPELINE_FILE_NAME, + ] + + load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames + load_components_from_hub = len(custom_components) > 0 + + if load_pipe_from_hub and not trust_remote_code: + raise ValueError( + f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + + if load_components_from_hub and not trust_remote_code: + raise ValueError( + f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly " + f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + + # retrieve passed components that should not be downloaded + pipeline_class = _get_pipeline_class( + cls, + config_dict, + load_connected_pipeline=load_connected_pipeline, + custom_pipeline=custom_pipeline, + repo_id=pretrained_model_name if load_pipe_from_hub else None, + hub_revision=revision, + class_name=custom_class_name, + cache_dir=cache_dir, + revision=custom_revision, + ) + expected_components, _ = cls._get_signature_keys(pipeline_class) + passed_components = [k for k in expected_components if k in kwargs] + + if ( + use_safetensors + and not allow_pickle + and not is_safetensors_compatible( + model_filenames, passed_components=passed_components, folder_names=model_folder_names + ) + ): + raise EnvironmentError( + f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" + ) + if from_flax: + ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] + elif use_safetensors and is_safetensors_compatible( + model_filenames, passed_components=passed_components, folder_names=model_folder_names + ): + ignore_patterns = ["*.bin", "*.msgpack"] + + use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + + safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} + safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} + if ( + len(safetensors_variant_filenames) > 0 + and safetensors_model_filenames != safetensors_variant_filenames + ): + logger.warning( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." + ) + else: + ignore_patterns = ["*.safetensors", "*.msgpack"] + + use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + + bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} + bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} + if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: + logger.warning( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." + ) + + # Don't download any objects that are passed + allow_patterns = [ + p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components) + ] + + if pipeline_class._load_connected_pipes: + allow_patterns.append("README.md") + + # Don't download index files of forbidden patterns either + ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns] + re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns] + re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns] + + expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)] + expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)] + + snapshot_folder = Path(config_file).parent + pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files) + + if pipeline_is_cached and not force_download: + # if the pipeline is cached, we can directly return it + # else call snapshot_download + return snapshot_folder + + user_agent = {"pipeline_class": cls.__name__} + if custom_pipeline is not None and not custom_pipeline.endswith(".py"): + user_agent["custom_pipeline"] = custom_pipeline + + # download all allow_patterns - ignore_patterns + try: + cached_folder = snapshot_download( + pretrained_model_name, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + + # retrieve pipeline class from local file + cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None) + cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name + + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None + + if pipeline_class is not None and pipeline_class._load_connected_pipes: + modelcard = ModelCard.load(os.path.join(cached_folder, "README.md")) + connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], []) + for connected_pipe_repo_id in connected_pipes: + download_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "variant": variant, + "use_safetensors": use_safetensors, + } + DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs) + + return cached_folder + + except FileNotFoundError: + # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache. + # This can happen in two cases: + # 1. If the user passed `local_files_only=True` => we raise the error directly + # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error + if model_info_call_error is None: + # 1. user passed `local_files_only=True` + raise + else: + # 2. we forced `local_files_only=True` when `model_info` failed + raise EnvironmentError( + f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred" + " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace" + " above." + ) from model_info_call_error + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + optional_names = list(optional_parameters) + for name in optional_names: + if name in cls._optional_components: + expected_modules.add(name) + optional_parameters.remove(name) + + return expected_modules, optional_parameters + + @classmethod + def _get_signature_types(cls): + signature_types = {} + for k, v in inspect.signature(cls.__init__).parameters.items(): + if inspect.isclass(v.annotation): + signature_types[k] = (v.annotation,) + elif get_origin(v.annotation) == Union: + signature_types[k] = get_args(v.annotation) + else: + logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.") + return signature_types + + @property + def components(self) -> Dict[str, Any]: + r""" + The `self.components` property can be useful to run different pipelines with the same weights and + configurations without reallocating additional memory. + + Returns (`dict`): + A dictionary containing all the modules needed to initialize the pipeline. + + Examples: + + ```py + >>> from diffusers import ( + ... StableDiffusionPipeline, + ... StableDiffusionImg2ImgPipeline, + ... StableDiffusionInpaintPipeline, + ... ) + + >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components) + >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components) + ``` + """ + expected_modules, optional_parameters = self._get_signature_keys(self) + components = { + k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters + } + + if set(components.keys()) != expected_modules: + raise ValueError( + f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" + f" {expected_modules} to be defined, but {components.keys()} are defined." + ) + + return components + + @staticmethod + def numpy_to_pil(images): + """ + Convert a NumPy image or a batch of images to a PIL image. + """ + return numpy_to_pil(images) + + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + r""" + Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). When this + option is enabled, you should observe lower GPU memory usage and a potential speed up during inference. Speed + up during training is not guaranteed. + + + + ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes + precedent. + + + + Parameters: + attention_op (`Callable`, *optional*): + Override the default `None` operator for use as `op` argument to the + [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) + function of xFormers. + + Examples: + + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + + >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + >>> # Workaround for not accepting attention shape using VAE for Flash Attention + >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None) + ``` + """ + self.set_use_memory_efficient_attention_xformers(True, attention_op) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). + """ + self.set_use_memory_efficient_attention_xformers(False) + + def set_use_memory_efficient_attention_xformers( + self, valid: bool, attention_op: Optional[Callable] = None + ) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid, attention_op) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + fn_recursive_set_mem_eff(module) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor + in slices to compute attention in several steps. For more than one attention head, the computation is performed + sequentially over each head. This is useful to save some memory in exchange for a small speed decrease. + + + + ⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch + 2.0 or xFormers. These attention computations are already very memory efficient so you won't need to enable + this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious slow downs! + + + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + + Examples: + + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", + ... torch_dtype=torch.float16, + ... use_safetensors=True, + ... ) + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> pipe.enable_attention_slicing() + >>> image = pipe(prompt).images[0] + ``` + """ + self.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously called, attention is + computed in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def set_attention_slice(self, slice_size: Optional[int]): + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")] + + for module in modules: + module.set_attention_slice(slice_size) + + @classmethod + def from_pipe(cls, pipeline, **kwargs): + r""" + Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing + pipeline components without reallocating additional memory. + + Arguments: + pipeline (`DiffusionPipeline`): + The pipeline from which to create a new pipeline. + + Returns: + `DiffusionPipeline`: + A new pipeline with the same weights and configurations as `pipeline`. + + Examples: + + ```py + >>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe) + ``` + """ + + original_config = dict(pipeline.config) + torch_dtype = kwargs.pop("torch_dtype", None) + + # derive the pipeline class to instantiate + custom_pipeline = kwargs.pop("custom_pipeline", None) + custom_revision = kwargs.pop("custom_revision", None) + + if custom_pipeline is not None: + pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision) + else: + pipeline_class = cls + + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + # true_optional_modules are optional components with default value in signature so it is ok not to pass them to `__init__` + # e.g. `image_encoder` for StableDiffusionPipeline + parameters = inspect.signature(cls.__init__).parameters + true_optional_modules = set( + {k for k, v in parameters.items() if v.default != inspect._empty and k in expected_modules} + ) + + # get the class of each component based on its type hint + # e.g. {"unet": UNet2DConditionModel, "text_encoder": CLIPTextMode} + component_types = pipeline_class._get_signature_types() + + pretrained_model_name_or_path = original_config.pop("_name_or_path", None) + # allow users pass modules in `kwargs` to override the original pipeline's components + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + original_class_obj = {} + for name, component in pipeline.components.items(): + if name in expected_modules and name not in passed_class_obj: + # for model components, we will not switch over if the class does not matches the type hint in the new pipeline's signature + if ( + not isinstance(component, ModelMixin) + or type(component) in component_types[name] + or (component is None and name in cls._optional_components) + ): + original_class_obj[name] = component + else: + logger.warning( + f"component {name} is not switched over to new pipeline because type does not match the expected." + f" {name} is {type(component)} while the new pipeline expect {component_types[name]}." + f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`" + ) + + # allow users pass optional kwargs to override the original pipelines config attribute + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + original_pipe_kwargs = { + k: original_config[k] + for k in original_config.keys() + if k in optional_kwargs and k not in passed_pipe_kwargs + } + + # config attribute that were not expected by pipeline is stored as its private attribute + # (i.e. when the original pipeline was also instantiated with `from_pipe` from another pipeline that has this config) + # in this case, we will pass them as optional arguments if they can be accepted by the new pipeline + additional_pipe_kwargs = [ + k[1:] + for k in original_config.keys() + if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs + ] + for k in additional_pipe_kwargs: + original_pipe_kwargs[k] = original_config.pop(f"_{k}") + + pipeline_kwargs = { + **passed_class_obj, + **original_class_obj, + **passed_pipe_kwargs, + **original_pipe_kwargs, + **kwargs, + } + + # store unused config as private attribute in the new pipeline + unused_original_config = { + f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs + } + + missing_modules = ( + set(expected_modules) + - set(pipeline._optional_components) + - set(pipeline_kwargs.keys()) + - set(true_optional_modules) + ) + + if len(missing_modules) > 0: + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed" + ) + + new_pipeline = pipeline_class(**pipeline_kwargs) + if pretrained_model_name_or_path is not None: + new_pipeline.register_to_config(_name_or_path=pretrained_model_name_or_path) + new_pipeline.register_to_config(**unused_original_config) + + if torch_dtype is not None: + new_pipeline.to(dtype=torch_dtype) + + return new_pipeline + + +class StableDiffusionMixin: + r""" + Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion) + """ + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + """ + self.fusing_unet = False + self.fusing_vae = False + + if unet: + self.fusing_unet = True + self.unet.fuse_qkv_projections() + self.unet.set_attn_processor(FusedAttnProcessor2_0()) + + if vae: + if not isinstance(self.vae, AutoencoderKL): + raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") + + self.fusing_vae = True + self.vae.fuse_qkv_projections() + self.vae.set_attn_processor(FusedAttnProcessor2_0()) + + def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """Disable QKV projection fusion if enabled. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + + """ + if unet: + if not self.fusing_unet: + logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") + else: + self.unet.unfuse_qkv_projections() + self.fusing_unet = False + + if vae: + if not self.fusing_vae: + logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.") + else: + self.vae.unfuse_qkv_projections() + self.fusing_vae = False diff --git a/diffusers/src/diffusers/pipelines/pixart_alpha/__init__.py b/diffusers/src/diffusers/pipelines/pixart_alpha/__init__.py new file mode 100755 index 0000000..7cb870f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pixart_alpha/__init__.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"] + _import_structure["pipeline_pixart_sigma"] = ["PixArtSigmaPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, + PixArtAlphaPipeline, + ) + from .pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN, PixArtSigmaPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py new file mode 100755 index 0000000..5b220df --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -0,0 +1,969 @@ +# Copyright 2024 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderKL, PixArtTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PixArtAlphaPipeline + + >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too. + >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt).images[0] + ``` +""" + +ASPECT_RATIO_1024_BIN = { + "0.25": [512.0, 2048.0], + "0.28": [512.0, 1856.0], + "0.32": [576.0, 1792.0], + "0.33": [576.0, 1728.0], + "0.35": [576.0, 1664.0], + "0.4": [640.0, 1600.0], + "0.42": [640.0, 1536.0], + "0.48": [704.0, 1472.0], + "0.5": [704.0, 1408.0], + "0.52": [704.0, 1344.0], + "0.57": [768.0, 1344.0], + "0.6": [768.0, 1280.0], + "0.68": [832.0, 1216.0], + "0.72": [832.0, 1152.0], + "0.78": [896.0, 1152.0], + "0.82": [896.0, 1088.0], + "0.88": [960.0, 1088.0], + "0.94": [960.0, 1024.0], + "1.0": [1024.0, 1024.0], + "1.07": [1024.0, 960.0], + "1.13": [1088.0, 960.0], + "1.21": [1088.0, 896.0], + "1.29": [1152.0, 896.0], + "1.38": [1152.0, 832.0], + "1.46": [1216.0, 832.0], + "1.67": [1280.0, 768.0], + "1.75": [1344.0, 768.0], + "2.0": [1408.0, 704.0], + "2.09": [1472.0, 704.0], + "2.4": [1536.0, 640.0], + "2.5": [1600.0, 640.0], + "3.0": [1728.0, 576.0], + "4.0": [2048.0, 512.0], +} + +ASPECT_RATIO_512_BIN = { + "0.25": [256.0, 1024.0], + "0.28": [256.0, 928.0], + "0.32": [288.0, 896.0], + "0.33": [288.0, 864.0], + "0.35": [288.0, 832.0], + "0.4": [320.0, 800.0], + "0.42": [320.0, 768.0], + "0.48": [352.0, 736.0], + "0.5": [352.0, 704.0], + "0.52": [352.0, 672.0], + "0.57": [384.0, 672.0], + "0.6": [384.0, 640.0], + "0.68": [416.0, 608.0], + "0.72": [416.0, 576.0], + "0.78": [448.0, 576.0], + "0.82": [448.0, 544.0], + "0.88": [480.0, 544.0], + "0.94": [480.0, 512.0], + "1.0": [512.0, 512.0], + "1.07": [512.0, 480.0], + "1.13": [544.0, 480.0], + "1.21": [544.0, 448.0], + "1.29": [576.0, 448.0], + "1.38": [576.0, 416.0], + "1.46": [608.0, 416.0], + "1.67": [640.0, 384.0], + "1.75": [672.0, 384.0], + "2.0": [704.0, 352.0], + "2.09": [736.0, 352.0], + "2.4": [768.0, 320.0], + "2.5": [800.0, 320.0], + "3.0": [864.0, 288.0], + "4.0": [1024.0, 256.0], +} + +ASPECT_RATIO_256_BIN = { + "0.25": [128.0, 512.0], + "0.28": [128.0, 464.0], + "0.32": [144.0, 448.0], + "0.33": [144.0, 432.0], + "0.35": [144.0, 416.0], + "0.4": [160.0, 400.0], + "0.42": [160.0, 384.0], + "0.48": [176.0, 368.0], + "0.5": [176.0, 352.0], + "0.52": [176.0, 336.0], + "0.57": [192.0, 336.0], + "0.6": [192.0, 320.0], + "0.68": [208.0, 304.0], + "0.72": [208.0, 288.0], + "0.78": [224.0, 288.0], + "0.82": [224.0, 272.0], + "0.88": [240.0, 272.0], + "0.94": [240.0, 256.0], + "1.0": [256.0, 256.0], + "1.07": [256.0, 240.0], + "1.13": [272.0, 240.0], + "1.21": [272.0, 224.0], + "1.29": [288.0, 224.0], + "1.38": [288.0, 208.0], + "1.46": [304.0, 208.0], + "1.67": [320.0, 192.0], + "1.75": [336.0, 192.0], + "2.0": [352.0, 176.0], + "2.09": [368.0, 176.0], + "2.4": [384.0, 160.0], + "2.5": [400.0, 160.0], + "3.0": [432.0, 144.0], + "4.0": [512.0, 128.0], +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtAlphaPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Alpha. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`PixArtTransformer2DModel`]): + A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 120, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 120, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.transformer.config.sample_size == 128: + resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + if num_inference_steps == 1: + # For DMD one step sampling: https://arxiv.org/abs/2311.18828 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py new file mode 100755 index 0000000..69f0289 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py @@ -0,0 +1,880 @@ +# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderKL, PixArtTransformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +ASPECT_RATIO_2048_BIN = { + "0.25": [1024.0, 4096.0], + "0.26": [1024.0, 3968.0], + "0.27": [1024.0, 3840.0], + "0.28": [1024.0, 3712.0], + "0.32": [1152.0, 3584.0], + "0.33": [1152.0, 3456.0], + "0.35": [1152.0, 3328.0], + "0.4": [1280.0, 3200.0], + "0.42": [1280.0, 3072.0], + "0.48": [1408.0, 2944.0], + "0.5": [1408.0, 2816.0], + "0.52": [1408.0, 2688.0], + "0.57": [1536.0, 2688.0], + "0.6": [1536.0, 2560.0], + "0.68": [1664.0, 2432.0], + "0.72": [1664.0, 2304.0], + "0.78": [1792.0, 2304.0], + "0.82": [1792.0, 2176.0], + "0.88": [1920.0, 2176.0], + "0.94": [1920.0, 2048.0], + "1.0": [2048.0, 2048.0], + "1.07": [2048.0, 1920.0], + "1.13": [2176.0, 1920.0], + "1.21": [2176.0, 1792.0], + "1.29": [2304.0, 1792.0], + "1.38": [2304.0, 1664.0], + "1.46": [2432.0, 1664.0], + "1.67": [2560.0, 1536.0], + "1.75": [2688.0, 1536.0], + "2.0": [2816.0, 1408.0], + "2.09": [2944.0, 1408.0], + "2.4": [3072.0, 1280.0], + "2.5": [3200.0, 1280.0], + "2.89": [3328.0, 1152.0], + "3.0": [3456.0, 1152.0], + "3.11": [3584.0, 1152.0], + "3.62": [3712.0, 1024.0], + "3.75": [3840.0, 1024.0], + "3.88": [3968.0, 1024.0], + "4.0": [4096.0, 1024.0], +} + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PixArtSigmaPipeline + + >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-Sigma-XL-2-512-MS" too. + >>> pipe = PixArtSigmaPipeline.from_pretrained( + ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16 + ... ) + >>> # Enable memory optimizations. + >>> # pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtSigmaPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Sigma. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300 + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 300, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 256: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/__init__.py new file mode 100755 index 0000000..70f5b1a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["SemanticStableDiffusionPipelineOutput"] + _import_structure["pipeline_semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_output.py b/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_output.py new file mode 100755 index 0000000..3499129 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_output.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class SemanticStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`List[bool]`) + List indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or + `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] diff --git a/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py new file mode 100755 index 0000000..6f83071 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -0,0 +1,722 @@ +import inspect +from itertools import repeat +from typing import Callable, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import SemanticStableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with latent editing. + + This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionPipeline`]. Check the superclass + documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular + device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`Q16SafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + editing_prompt: Optional[Union[str, List[str]]] = None, + editing_prompt_embeddings: Optional[torch.Tensor] = None, + reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, + edit_guidance_scale: Optional[Union[float, List[float]]] = 5, + edit_warmup_steps: Optional[Union[int, List[int]]] = 10, + edit_cooldown_steps: Optional[Union[int, List[int]]] = None, + edit_threshold: Optional[Union[float, List[float]]] = 0.9, + edit_momentum_scale: Optional[float] = 0.1, + edit_mom_beta: Optional[float] = 0.4, + edit_weights: Optional[List[float]] = None, + sem_guidance: Optional[List[torch.Tensor]] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + editing_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to use for semantic guidance. Semantic guidance is disabled by setting + `editing_prompt = None`. Guidance direction of prompt should be specified via + `reverse_editing_direction`. + editing_prompt_embeddings (`torch.Tensor`, *optional*): + Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be + specified via `reverse_editing_direction`. + reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): + Whether the corresponding prompt in `editing_prompt` should be increased or decreased. + edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): + Guidance scale for semantic guidance. If provided as a list, values should correspond to + `editing_prompt`. + edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): + Number of diffusion steps (for each prompt) for which semantic guidance is not applied. Momentum is + calculated for those steps and applied once all warmup periods are over. + edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): + Number of diffusion steps (for each prompt) after which semantic guidance is longer applied. + edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): + Threshold of semantic guidance. + edit_momentum_scale (`float`, *optional*, defaults to 0.1): + Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0, + momentum is disabled. Momentum is already built up during warmup (for diffusion steps smaller than + `sld_warmup_steps`). Momentum is only added to latent guidance once all warmup periods are finished. + edit_mom_beta (`float`, *optional*, defaults to 0.4): + Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous + momentum is kept. Momentum is already built up during warmup (for diffusion steps smaller than + `edit_warmup_steps`). + edit_weights (`List[float]`, *optional*, defaults to `None`): + Indicates how much each individual concept should influence the overall guidance. If no weights are + provided all concepts are applied equally. + sem_guidance (`List[torch.Tensor]`, *optional*): + List of pre-generated guidance vectors to be applied at generation. Length of the list has to + correspond to `num_inference_steps`. + + Examples: + + ```py + >>> import torch + >>> from diffusers import SemanticStableDiffusionPipeline + + >>> pipe = SemanticStableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> out = pipe( + ... prompt="a photo of the face of a woman", + ... num_images_per_prompt=1, + ... guidance_scale=7, + ... editing_prompt=[ + ... "smiling, smile", # Concepts to apply + ... "glasses, wearing glasses", + ... "curls, wavy hair, curly hair", + ... "beard, full beard, mustache", + ... ], + ... reverse_editing_direction=[ + ... False, + ... False, + ... False, + ... False, + ... ], # Direction of guidance i.e. increase all concepts + ... edit_warmup_steps=[10, 10, 10, 10], # Warmup period for each concept + ... edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept + ... edit_threshold=[ + ... 0.99, + ... 0.975, + ... 0.925, + ... 0.96, + ... ], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions + ... edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance + ... edit_mom_beta=0.6, # Momentum beta + ... edit_weights=[1, 1, 1, 1, 1], # Weights of the individual concepts against each other + ... ) + >>> image = out.images[0] + ``` + + Returns: + [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, + [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images and the second element + is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" + (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + + if editing_prompt: + enable_edit_guidance = True + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + enabled_editing_prompts = len(editing_prompt) + elif editing_prompt_embeddings is not None: + enable_edit_guidance = True + enabled_editing_prompts = editing_prompt_embeddings.shape[0] + else: + enabled_editing_prompts = 0 + enable_edit_guidance = False + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(device))[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if enable_edit_guidance: + # get safety text embeddings + if editing_prompt_embeddings is None: + edit_concepts_input = self.tokenizer( + [x for item in editing_prompt for x in repeat(item, batch_size)], + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + + edit_concepts_input_ids = edit_concepts_input.input_ids + + if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode( + edit_concepts_input_ids[:, self.tokenizer.model_max_length :] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] + edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0] + else: + edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed_edit, seq_len_edit, _ = edit_concepts.shape + edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1) + edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if enable_edit_guidance: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + # get the initial random noise unless the user supplied it + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Initialize edit_momentum to None + edit_momentum = None + + self.uncond_estimates = None + self.text_estimates = None + self.edit_estimates = None + self.sem_guidance = None + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64] + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + noise_pred_edit_concepts = noise_pred_out[2:] + + # default text guidance + noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond) + # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0]) + + if self.uncond_estimates is None: + self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape)) + self.uncond_estimates[i] = noise_pred_uncond.detach().cpu() + + if self.text_estimates is None: + self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) + self.text_estimates[i] = noise_pred_text.detach().cpu() + + if self.edit_estimates is None and enable_edit_guidance: + self.edit_estimates = torch.zeros( + (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) + ) + + if self.sem_guidance is None: + self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) + + if edit_momentum is None: + edit_momentum = torch.zeros_like(noise_guidance) + + if enable_edit_guidance: + concept_weights = torch.zeros( + (len(noise_pred_edit_concepts), noise_guidance.shape[0]), + device=device, + dtype=noise_guidance.dtype, + ) + noise_guidance_edit = torch.zeros( + (len(noise_pred_edit_concepts), *noise_guidance.shape), + device=device, + dtype=noise_guidance.dtype, + ) + # noise_guidance_edit = torch.zeros_like(noise_guidance) + warmup_inds = [] + for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): + self.edit_estimates[i, c] = noise_pred_edit_concept + if isinstance(edit_guidance_scale, list): + edit_guidance_scale_c = edit_guidance_scale[c] + else: + edit_guidance_scale_c = edit_guidance_scale + + if isinstance(edit_threshold, list): + edit_threshold_c = edit_threshold[c] + else: + edit_threshold_c = edit_threshold + if isinstance(reverse_editing_direction, list): + reverse_editing_direction_c = reverse_editing_direction[c] + else: + reverse_editing_direction_c = reverse_editing_direction + if edit_weights: + edit_weight_c = edit_weights[c] + else: + edit_weight_c = 1.0 + if isinstance(edit_warmup_steps, list): + edit_warmup_steps_c = edit_warmup_steps[c] + else: + edit_warmup_steps_c = edit_warmup_steps + + if isinstance(edit_cooldown_steps, list): + edit_cooldown_steps_c = edit_cooldown_steps[c] + elif edit_cooldown_steps is None: + edit_cooldown_steps_c = i + 1 + else: + edit_cooldown_steps_c = edit_cooldown_steps + if i >= edit_warmup_steps_c: + warmup_inds.append(c) + if i >= edit_cooldown_steps_c: + noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept) + continue + + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond + # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3)) + tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3)) + + tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) + if reverse_editing_direction_c: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 + concept_weights[c, :] = tmp_weights + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp.dtype == torch.float32: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp.dtype) + + noise_guidance_edit_tmp = torch.where( + torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None], + noise_guidance_edit_tmp, + torch.zeros_like(noise_guidance_edit_tmp), + ) + noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp + + # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp + + warmup_inds = torch.tensor(warmup_inds).to(device) + if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: + concept_weights = concept_weights.to("cpu") # Offload to cpu + noise_guidance_edit = noise_guidance_edit.to("cpu") + + concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds) + concept_weights_tmp = torch.where( + concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp + ) + concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) + # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) + + noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds) + noise_guidance_edit_tmp = torch.einsum( + "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp + ) + noise_guidance = noise_guidance + noise_guidance_edit_tmp + + self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu() + + del noise_guidance_edit_tmp + del concept_weights_tmp + concept_weights = concept_weights.to(device) + noise_guidance_edit = noise_guidance_edit.to(device) + + concept_weights = torch.where( + concept_weights < 0, torch.zeros_like(concept_weights), concept_weights + ) + + concept_weights = torch.nan_to_num(concept_weights) + + noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) + noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device) + + noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum + + edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit + + if warmup_inds.shape[0] == len(noise_pred_edit_concepts): + noise_guidance = noise_guidance + noise_guidance_edit + self.sem_guidance[i] = noise_guidance_edit.detach().cpu() + + if sem_guidance is not None: + edit_guidance = sem_guidance[i].to(device) + noise_guidance = noise_guidance + edit_guidance + + noise_pred = noise_pred_uncond + noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 8. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/shap_e/__init__.py b/diffusers/src/diffusers/pipelines/shap_e/__init__.py new file mode 100755 index 0000000..4ed563c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/shap_e/__init__.py @@ -0,0 +1,71 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["camera"] = ["create_pan_cameras"] + _import_structure["pipeline_shap_e"] = ["ShapEPipeline"] + _import_structure["pipeline_shap_e_img2img"] = ["ShapEImg2ImgPipeline"] + _import_structure["renderer"] = [ + "BoundingBoxVolume", + "ImportanceRaySampler", + "MLPNeRFModelOutput", + "MLPNeRSTFModel", + "ShapEParamsProjModel", + "ShapERenderer", + "StratifiedRaySampler", + "VoidNeRFModel", + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .camera import create_pan_cameras + from .pipeline_shap_e import ShapEPipeline + from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline + from .renderer import ( + BoundingBoxVolume, + ImportanceRaySampler, + MLPNeRFModelOutput, + MLPNeRSTFModel, + ShapEParamsProjModel, + ShapERenderer, + StratifiedRaySampler, + VoidNeRFModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/shap_e/camera.py b/diffusers/src/diffusers/pipelines/shap_e/camera.py new file mode 100755 index 0000000..d4b94c3 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/shap_e/camera.py @@ -0,0 +1,147 @@ +# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import torch + + +@dataclass +class DifferentiableProjectiveCamera: + """ + Implements a batch, differentiable, standard pinhole camera + """ + + origin: torch.Tensor # [batch_size x 3] + x: torch.Tensor # [batch_size x 3] + y: torch.Tensor # [batch_size x 3] + z: torch.Tensor # [batch_size x 3] + width: int + height: int + x_fov: float + y_fov: float + shape: Tuple[int] + + def __post_init__(self): + assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0] + assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3 + assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2 + + def resolution(self): + return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32)) + + def fov(self): + return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32)) + + def get_image_coords(self) -> torch.Tensor: + """ + :return: coords of shape (width * height, 2) + """ + pixel_indices = torch.arange(self.height * self.width) + coords = torch.stack( + [ + pixel_indices % self.width, + torch.div(pixel_indices, self.width, rounding_mode="trunc"), + ], + axis=1, + ) + return coords + + @property + def camera_rays(self): + batch_size, *inner_shape = self.shape + inner_batch_size = int(np.prod(inner_shape)) + + coords = self.get_image_coords() + coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape]) + rays = self.get_camera_rays(coords) + + rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3) + + return rays + + def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor: + batch_size, *shape, n_coords = coords.shape + assert n_coords == 2 + assert batch_size == self.origin.shape[0] + + flat = coords.view(batch_size, -1, 2) + + res = self.resolution() + fov = self.fov() + + fracs = (flat.float() / (res - 1)) * 2 - 1 + fracs = fracs * torch.tan(fov / 2) + + fracs = fracs.view(batch_size, -1, 2) + directions = ( + self.z.view(batch_size, 1, 3) + + self.x.view(batch_size, 1, 3) * fracs[:, :, :1] + + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:] + ) + directions = directions / directions.norm(dim=-1, keepdim=True) + rays = torch.stack( + [ + torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]), + directions, + ], + dim=2, + ) + return rays.view(batch_size, *shape, 2, 3) + + def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera": + """ + Creates a new camera for the resized view assuming the aspect ratio does not change. + """ + assert width * self.height == height * self.width, "The aspect ratio should not change." + return DifferentiableProjectiveCamera( + origin=self.origin, + x=self.x, + y=self.y, + z=self.z, + width=width, + height=height, + x_fov=self.x_fov, + y_fov=self.y_fov, + ) + + +def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera: + origins = [] + xs = [] + ys = [] + zs = [] + for theta in np.linspace(0, 2 * np.pi, num=20): + z = np.array([np.sin(theta), np.cos(theta), -0.5]) + z /= np.sqrt(np.sum(z**2)) + origin = -z * 4 + x = np.array([np.cos(theta), -np.sin(theta), 0.0]) + y = np.cross(z, x) + origins.append(origin) + xs.append(x) + ys.append(y) + zs.append(z) + return DifferentiableProjectiveCamera( + origin=torch.from_numpy(np.stack(origins, axis=0)).float(), + x=torch.from_numpy(np.stack(xs, axis=0)).float(), + y=torch.from_numpy(np.stack(ys, axis=0)).float(), + z=torch.from_numpy(np.stack(zs, axis=0)).float(), + width=size, + height=size, + x_fov=0.7, + y_fov=0.7, + shape=(1, len(xs)), + ) diff --git a/diffusers/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/diffusers/src/diffusers/pipelines/shap_e/pipeline_shap_e.py new file mode 100755 index 0000000..f87f28e --- /dev/null +++ b/diffusers/src/diffusers/pipelines/shap_e/pipeline_shap_e.py @@ -0,0 +1,334 @@ +# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...models import PriorTransformer +from ...schedulers import HeunDiscreteScheduler +from ...utils import ( + BaseOutput, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .renderer import ShapERenderer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from diffusers.utils import export_to_gif + + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + >>> repo = "openai/shap-e" + >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> guidance_scale = 15.0 + >>> prompt = "a shark" + + >>> images = pipe( + ... prompt, + ... guidance_scale=guidance_scale, + ... num_inference_steps=64, + ... frame_size=256, + ... ).images + + >>> gif_path = export_to_gif(images[0], "shark_3d.gif") + ``` +""" + + +@dataclass +class ShapEPipelineOutput(BaseOutput): + """ + Output class for [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`]. + + Args: + images (`torch.Tensor`) + A list of images for 3D rendering. + """ + + images: Union[List[List[PIL.Image.Image]], List[List[np.ndarray]]] + + +class ShapEPipeline(DiffusionPipeline): + """ + Pipeline for generating latent representation of a 3D asset and rendering with the NeRF method. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`~transformers.CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + scheduler ([`HeunDiscreteScheduler`]): + A scheduler to be used in combination with the `prior` model to generate image embedding. + shap_e_renderer ([`ShapERenderer`]): + Shap-E renderer projects the generated latents into parameters of a MLP to create 3D objects with the NeRF + rendering method. + """ + + model_cpu_offload_seq = "text_encoder->prior" + _exclude_from_cpu_offload = ["shap_e_renderer"] + + def __init__( + self, + prior: PriorTransformer, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: HeunDiscreteScheduler, + shap_e_renderer: ShapERenderer, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + shap_e_renderer=shap_e_renderer, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + ): + len(prompt) if isinstance(prompt, list) else 1 + + # YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file + self.tokenizer.pad_token_id = 0 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + prompt_embeds = text_encoder_output.text_embeds + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + # in Shap-E it normalize the prompt_embeds and then later rescale it + prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # Rescale the features to have unit variance + prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds + + return prompt_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + guidance_scale: float = 4.0, + frame_size: int = 64, + output_type: Optional[str] = "pil", # pil, np, latent, mesh + return_dict: bool = True, + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + frame_size (`int`, *optional*, default to 64): + The width and height of each image frame of the generated 3D output. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`), `"latent"` (`torch.Tensor`), or mesh ([`MeshDecoderOutput`]). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] instead of a plain + tuple. + + Examples: + + Returns: + [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + + # prior + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_embeddings = self.prior.config.num_embeddings + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, num_embeddings * embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim + latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = self.prior( + scaled_model_input, + timestep=t, + proj_embedding=prompt_embeds, + ).predicted_image_embedding + + # remove the variance + noise_pred, _ = noise_pred.split( + scaled_model_input.shape[2], dim=2 + ) # batch_size, num_embeddings, embedding_dim + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + timestep=t, + sample=latents, + ).prev_sample + + # Offload all models + self.maybe_free_model_hooks() + + if output_type not in ["np", "pil", "latent", "mesh"]: + raise ValueError( + f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}" + ) + + if output_type == "latent": + return ShapEPipelineOutput(images=latents) + + images = [] + if output_type == "mesh": + for i, latent in enumerate(latents): + mesh = self.shap_e_renderer.decode_to_mesh( + latent[None, :], + device, + ) + images.append(mesh) + + else: + # np, pil + for i, latent in enumerate(latents): + image = self.shap_e_renderer.decode_to_image( + latent[None, :], + device, + size=frame_size, + ) + images.append(image) + + images = torch.stack(images) + + images = images.cpu().numpy() + + if output_type == "pil": + images = [self.numpy_to_pil(image) for image in images] + + if not return_dict: + return (images,) + + return ShapEPipelineOutput(images=images) diff --git a/diffusers/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/diffusers/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py new file mode 100755 index 0000000..7cc145e --- /dev/null +++ b/diffusers/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py @@ -0,0 +1,321 @@ +# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModel + +from ...models import PriorTransformer +from ...schedulers import HeunDiscreteScheduler +from ...utils import ( + BaseOutput, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .renderer import ShapERenderer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from diffusers.utils import export_to_gif, load_image + + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + >>> repo = "openai/shap-e-img2img" + >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> guidance_scale = 3.0 + >>> image_url = "https://hf.co/datasets/diffusers/docs-images/resolve/main/shap-e/corgi.png" + >>> image = load_image(image_url).convert("RGB") + + >>> images = pipe( + ... image, + ... guidance_scale=guidance_scale, + ... num_inference_steps=64, + ... frame_size=256, + ... ).images + + >>> gif_path = export_to_gif(images[0], "corgi_3d.gif") + ``` +""" + + +@dataclass +class ShapEPipelineOutput(BaseOutput): + """ + Output class for [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`]. + + Args: + images (`torch.Tensor`) + A list of images for 3D rendering. + """ + + images: Union[PIL.Image.Image, np.ndarray] + + +class ShapEImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for generating latent representation of a 3D asset and rendering with the NeRF method from an image. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`~transformers.CLIPVisionModel`]): + Frozen image-encoder. + image_processor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to process images. + scheduler ([`HeunDiscreteScheduler`]): + A scheduler to be used in combination with the `prior` model to generate image embedding. + shap_e_renderer ([`ShapERenderer`]): + Shap-E renderer projects the generated latents into parameters of a MLP to create 3D objects with the NeRF + rendering method. + """ + + model_cpu_offload_seq = "image_encoder->prior" + _exclude_from_cpu_offload = ["shap_e_renderer"] + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, + scheduler: HeunDiscreteScheduler, + shap_e_renderer: ShapERenderer, + ): + super().__init__() + + self.register_modules( + prior=prior, + image_encoder=image_encoder, + image_processor=image_processor, + scheduler=scheduler, + shap_e_renderer=shap_e_renderer, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_image( + self, + image, + device, + num_images_per_prompt, + do_classifier_free_guidance, + ): + if isinstance(image, List) and isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + if not isinstance(image, torch.Tensor): + image = self.image_processor(image, return_tensors="pt").pixel_values[0].unsqueeze(0) + + image = image.to(dtype=self.image_encoder.dtype, device=device) + + image_embeds = self.image_encoder(image)["last_hidden_state"] + image_embeds = image_embeds[:, 1:, :].contiguous() # batch_size, dim, 256 + + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + return image_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image]], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + guidance_scale: float = 4.0, + frame_size: int = 64, + output_type: Optional[str] = "pil", # pil, np, latent, mesh + return_dict: bool = True, + ): + """ + The call function to the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image` or tensor representing an image batch to be used as the starting point. Can also accept image + latents as image, but if passing latents directly it is not encoded again. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + frame_size (`int`, *optional*, default to 64): + The width and height of each image frame of the generated 3D output. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`), `"latent"` (`torch.Tensor`), or mesh ([`MeshDecoderOutput`]). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] instead of a plain + tuple. + + Examples: + + Returns: + [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + elif isinstance(image, list) and isinstance(image[0], (torch.Tensor, PIL.Image.Image)): + batch_size = len(image) + else: + raise ValueError( + f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]` but is {type(image)}" + ) + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + image_embeds = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance) + + # prior + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_embeddings = self.prior.config.num_embeddings + embedding_dim = self.prior.config.embedding_dim + if latents is None: + latents = self.prepare_latents( + (batch_size, num_embeddings * embedding_dim), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim + latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = self.prior( + scaled_model_input, + timestep=t, + proj_embedding=image_embeds, + ).predicted_image_embedding + + # remove the variance + noise_pred, _ = noise_pred.split( + scaled_model_input.shape[2], dim=2 + ) # batch_size, num_embeddings, embedding_dim + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + timestep=t, + sample=latents, + ).prev_sample + + if output_type not in ["np", "pil", "latent", "mesh"]: + raise ValueError( + f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}" + ) + + # Offload all models + self.maybe_free_model_hooks() + + if output_type == "latent": + return ShapEPipelineOutput(images=latents) + + images = [] + if output_type == "mesh": + for i, latent in enumerate(latents): + mesh = self.shap_e_renderer.decode_to_mesh( + latent[None, :], + device, + ) + images.append(mesh) + + else: + # np, pil + for i, latent in enumerate(latents): + image = self.shap_e_renderer.decode_to_image( + latent[None, :], + device, + size=frame_size, + ) + images.append(image) + + images = torch.stack(images) + + images = images.cpu().numpy() + + if output_type == "pil": + images = [self.numpy_to_pil(image) for image in images] + + if not return_dict: + return (images,) + + return ShapEPipelineOutput(images=images) diff --git a/diffusers/src/diffusers/pipelines/shap_e/renderer.py b/diffusers/src/diffusers/pipelines/shap_e/renderer.py new file mode 100755 index 0000000..9d9f9d9 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/shap_e/renderer.py @@ -0,0 +1,1050 @@ +# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin +from ...utils import BaseOutput +from .camera import create_pan_cameras + + +def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor: + r""" + Sample from the given discrete probability distribution with replacement. + + The i-th bin is assumed to have mass pmf[i]. + + Args: + pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all() + n_samples: number of samples + + Return: + indices sampled with replacement + """ + + *shape, support_size, last_dim = pmf.shape + assert last_dim == 1 + + cdf = torch.cumsum(pmf.view(-1, support_size), dim=1) + inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device)) + + return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1) + + +def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor: + """ + Concatenate x and its positional encodings, following NeRF. + + Reference: https://arxiv.org/pdf/2210.04628.pdf + """ + if min_deg == max_deg: + return x + + scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device) + *shape, dim = x.shape + xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1) + assert xb.shape[-1] == dim * (max_deg - min_deg) + emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin() + return torch.cat([x, emb], dim=-1) + + +def encode_position(position): + return posenc_nerf(position, min_deg=0, max_deg=15) + + +def encode_direction(position, direction=None): + if direction is None: + return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8)) + else: + return posenc_nerf(direction, min_deg=0, max_deg=8) + + +def _sanitize_name(x: str) -> str: + return x.replace(".", "__") + + +def integrate_samples(volume_range, ts, density, channels): + r""" + Function integrating the model output. + + Args: + volume_range: Specifies the integral range [t0, t1] + ts: timesteps + density: torch.Tensor [batch_size, *shape, n_samples, 1] + channels: torch.Tensor [batch_size, *shape, n_samples, n_channels] + returns: + channels: integrated rgb output weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density + *transmittance)[i] weight for each rgb output at [..., i, :]. transmittance: transmittance of this volume + ) + """ + + # 1. Calculate the weights + _, _, dt = volume_range.partition(ts) + ddensity = density * dt + + mass = torch.cumsum(ddensity, dim=-2) + transmittance = torch.exp(-mass[..., -1, :]) + + alphas = 1.0 - torch.exp(-ddensity) + Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2)) + # This is the probability of light hitting and reflecting off of + # something at depth [..., i, :]. + weights = alphas * Ts + + # 2. Integrate channels + channels = torch.sum(channels * weights, dim=-2) + + return channels, weights, transmittance + + +def volume_query_points(volume, grid_size): + indices = torch.arange(grid_size**3, device=volume.bbox_min.device) + zs = indices % grid_size + ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size + xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size + combined = torch.stack([xs, ys, zs], dim=1) + return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min + + +def _convert_srgb_to_linear(u: torch.Tensor): + return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4) + + +def _create_flat_edge_indices( + flat_cube_indices: torch.Tensor, + grid_size: Tuple[int, int, int], +): + num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2] + y_offset = num_xs + num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2] + z_offset = num_xs + num_ys + return torch.stack( + [ + # Edges spanning x-axis. + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2], + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + (flat_cube_indices[:, 1] + 1) * grid_size[2] + + flat_cube_indices[:, 2], + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1, + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + (flat_cube_indices[:, 1] + 1) * grid_size[2] + + flat_cube_indices[:, 2] + + 1, + # Edges spanning y-axis. + ( + y_offset + + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + ), + ( + y_offset + + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + ), + ( + y_offset + + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1 + ), + ( + y_offset + + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1 + ), + # Edges spanning z-axis. + ( + z_offset + + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) + + flat_cube_indices[:, 1] * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) + + flat_cube_indices[:, 1] * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) + + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) + + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ], + dim=-1, + ) + + +class VoidNeRFModel(nn.Module): + """ + Implements the default empty space model where all queries are rendered as background. + """ + + def __init__(self, background, channel_scale=255.0): + super().__init__() + background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale) + + self.register_buffer("background", background) + + def forward(self, position): + background = self.background[None].to(position.device) + + shape = position.shape[:-1] + ones = [1] * (len(shape) - 1) + n_channels = background.shape[-1] + background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]) + + return background + + +@dataclass +class VolumeRange: + t0: torch.Tensor + t1: torch.Tensor + intersected: torch.Tensor + + def __post_init__(self): + assert self.t0.shape == self.t1.shape == self.intersected.shape + + def partition(self, ts): + """ + Partitions t0 and t1 into n_samples intervals. + + Args: + ts: [batch_size, *shape, n_samples, 1] + + Return: + + lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size, + *shape, n_samples, 1] + + where + ts \\in [lower, upper] deltas = upper - lower + """ + + mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5 + lower = torch.cat([self.t0[..., None, :], mids], dim=-2) + upper = torch.cat([mids, self.t1[..., None, :]], dim=-2) + delta = upper - lower + assert lower.shape == upper.shape == delta.shape == ts.shape + return lower, upper, delta + + +class BoundingBoxVolume(nn.Module): + """ + Axis-aligned bounding box defined by the two opposite corners. + """ + + def __init__( + self, + *, + bbox_min, + bbox_max, + min_dist: float = 0.0, + min_t_range: float = 1e-3, + ): + """ + Args: + bbox_min: the left/bottommost corner of the bounding box + bbox_max: the other corner of the bounding box + min_dist: all rays should start at least this distance away from the origin. + """ + super().__init__() + + self.min_dist = min_dist + self.min_t_range = min_t_range + + self.bbox_min = torch.tensor(bbox_min) + self.bbox_max = torch.tensor(bbox_max) + self.bbox = torch.stack([self.bbox_min, self.bbox_max]) + assert self.bbox.shape == (2, 3) + assert min_dist >= 0.0 + assert min_t_range > 0.0 + + def intersect( + self, + origin: torch.Tensor, + direction: torch.Tensor, + t0_lower: Optional[torch.Tensor] = None, + epsilon=1e-6, + ): + """ + Args: + origin: [batch_size, *shape, 3] + direction: [batch_size, *shape, 3] + t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. + params: Optional meta parameters in case Volume is parametric + epsilon: to stabilize calculations + + Return: + A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with + the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to + be on the boundary of the volume. + """ + + batch_size, *shape, _ = origin.shape + ones = [1] * len(shape) + bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device) + + def _safe_divide(a, b, epsilon=1e-6): + return a / torch.where(b < 0, b - epsilon, b + epsilon) + + ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon) + + # Cases to think about: + # + # 1. t1 <= t0: the ray does not pass through the AABB. + # 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin. + # 3. t0 <= 0 <= t1: the ray starts from inside the BB + # 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice. + # + # 1 and 4 are clearly handled from t0 < t1 below. + # Making t0 at least min_dist (>= 0) takes care of 2 and 3. + t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist) + t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values + assert t0.shape == t1.shape == (batch_size, *shape, 1) + if t0_lower is not None: + assert t0.shape == t0_lower.shape + t0 = torch.maximum(t0, t0_lower) + + intersected = t0 + self.min_t_range < t1 + t0 = torch.where(intersected, t0, torch.zeros_like(t0)) + t1 = torch.where(intersected, t1, torch.ones_like(t1)) + + return VolumeRange(t0=t0, t1=t1, intersected=intersected) + + +class StratifiedRaySampler(nn.Module): + """ + Instead of fixed intervals, a sample is drawn uniformly at random from each interval. + """ + + def __init__(self, depth_mode: str = "linear"): + """ + :param depth_mode: linear samples ts linearly in depth. harmonic ensures + closer points are sampled more densely. + """ + self.depth_mode = depth_mode + assert self.depth_mode in ("linear", "geometric", "harmonic") + + def sample( + self, + t0: torch.Tensor, + t1: torch.Tensor, + n_samples: int, + epsilon: float = 1e-3, + ) -> torch.Tensor: + """ + Args: + t0: start time has shape [batch_size, *shape, 1] + t1: finish time has shape [batch_size, *shape, 1] + n_samples: number of ts to sample + Return: + sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + ones = [1] * (len(t0.shape) - 1) + ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device) + + if self.depth_mode == "linear": + ts = t0 * (1.0 - ts) + t1 * ts + elif self.depth_mode == "geometric": + ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp() + elif self.depth_mode == "harmonic": + # The original NeRF recommends this interpolation scheme for + # spherical scenes, but there could be some weird edge cases when + # the observer crosses from the inner to outer volume. + ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts) + + mids = 0.5 * (ts[..., 1:] + ts[..., :-1]) + upper = torch.cat([mids, t1], dim=-1) + lower = torch.cat([t0, mids], dim=-1) + # yiyi notes: add a random seed here for testing, don't forget to remove + torch.manual_seed(0) + t_rand = torch.rand_like(ts) + + ts = lower + (upper - lower) * t_rand + return ts.unsqueeze(-1) + + +class ImportanceRaySampler(nn.Module): + """ + Given the initial estimate of densities, this samples more from regions/bins expected to have objects. + """ + + def __init__( + self, + volume_range: VolumeRange, + ts: torch.Tensor, + weights: torch.Tensor, + blur_pool: bool = False, + alpha: float = 1e-5, + ): + """ + Args: + volume_range: the range in which a ray intersects the given volume. + ts: earlier samples from the coarse rendering step + weights: discretized version of density * transmittance + blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF. + alpha: small value to add to weights. + """ + self.volume_range = volume_range + self.ts = ts.clone().detach() + self.weights = weights.clone().detach() + self.blur_pool = blur_pool + self.alpha = alpha + + @torch.no_grad() + def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor: + """ + Args: + t0: start time has shape [batch_size, *shape, 1] + t1: finish time has shape [batch_size, *shape, 1] + n_samples: number of ts to sample + Return: + sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + lower, upper, _ = self.volume_range.partition(self.ts) + + batch_size, *shape, n_coarse_samples, _ = self.ts.shape + + weights = self.weights + if self.blur_pool: + padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2) + maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :]) + weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :]) + weights = weights + self.alpha + pmf = weights / weights.sum(dim=-2, keepdim=True) + inds = sample_pmf(pmf, n_samples) + assert inds.shape == (batch_size, *shape, n_samples, 1) + assert (inds >= 0).all() and (inds < n_coarse_samples).all() + + t_rand = torch.rand(inds.shape, device=inds.device) + lower_ = torch.gather(lower, -2, inds) + upper_ = torch.gather(upper, -2, inds) + + ts = lower_ + (upper_ - lower_) * t_rand + ts = torch.sort(ts, dim=-2).values + return ts + + +@dataclass +class MeshDecoderOutput(BaseOutput): + """ + A 3D triangle mesh with optional data at the vertices and faces. + + Args: + verts (`torch.Tensor` of shape `(N, 3)`): + array of vertext coordinates + faces (`torch.Tensor` of shape `(N, 3)`): + array of triangles, pointing to indices in verts. + vertext_channels (Dict): + vertext coordinates for each color channel + """ + + verts: torch.Tensor + faces: torch.Tensor + vertex_channels: Dict[str, torch.Tensor] + + +class MeshDecoder(nn.Module): + """ + Construct meshes from Signed distance functions (SDFs) using marching cubes method + """ + + def __init__(self): + super().__init__() + cases = torch.zeros(256, 5, 3, dtype=torch.long) + masks = torch.zeros(256, 5, dtype=torch.bool) + + self.register_buffer("cases", cases) + self.register_buffer("masks", masks) + + def forward(self, field: torch.Tensor, min_point: torch.Tensor, size: torch.Tensor): + """ + For a signed distance field, produce a mesh using marching cubes. + + :param field: a 3D tensor of field values, where negative values correspond + to the outside of the shape. The dimensions correspond to the x, y, and z directions, respectively. + :param min_point: a tensor of shape [3] containing the point corresponding + to (0, 0, 0) in the field. + :param size: a tensor of shape [3] containing the per-axis distance from the + (0, 0, 0) field corner and the (-1, -1, -1) field corner. + """ + assert len(field.shape) == 3, "input must be a 3D scalar field" + dev = field.device + + cases = self.cases.to(dev) + masks = self.masks.to(dev) + + min_point = min_point.to(dev) + size = size.to(dev) + + grid_size = field.shape + grid_size_tensor = torch.tensor(grid_size).to(size) + + # Create bitmasks between 0 and 255 (inclusive) indicating the state + # of the eight corners of each cube. + bitmasks = (field > 0).to(torch.uint8) + bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1) + bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2) + bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4) + + # Compute corner coordinates across the entire grid. + corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype) + corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[ + :, None, None + ] + corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[ + :, None + ] + corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype) + + # Compute all vertices across all edges in the grid, even though we will + # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices. + # These are all midpoints, and don't account for interpolation (which is + # done later based on the used edge midpoints). + edge_midpoints = torch.cat( + [ + ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3), + ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3), + ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3), + ], + dim=0, + ) + + # Create a flat array of [X, Y, Z] indices for each cube. + cube_indices = torch.zeros( + grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long + ) + cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None] + cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None] + cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev) + flat_cube_indices = cube_indices.reshape(-1, 3) + + # Create a flat array mapping each cube to 12 global edge indices. + edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size) + + # Apply the LUT to figure out the triangles. + flat_bitmasks = bitmasks.reshape(-1).long() # must cast to long for indexing to believe this not a mask + local_tris = cases[flat_bitmasks] + local_masks = masks[flat_bitmasks] + # Compute the global edge indices for the triangles. + global_tris = torch.gather(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape( + local_tris.shape + ) + # Select the used triangles for each cube. + selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)] + + # Now we have a bunch of indices into the full list of possible vertices, + # but we want to reduce this list to only the used vertices. + used_vertex_indices = torch.unique(selected_tris.view(-1)) + used_edge_midpoints = edge_midpoints[used_vertex_indices] + old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long) + old_index_to_new_index[used_vertex_indices] = torch.arange( + len(used_vertex_indices), device=dev, dtype=torch.long + ) + + # Rewrite the triangles to use the new indices + faces = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(selected_tris.shape) + + # Compute the actual interpolated coordinates corresponding to edge midpoints. + v1 = torch.floor(used_edge_midpoints).to(torch.long) + v2 = torch.ceil(used_edge_midpoints).to(torch.long) + s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]] + s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]] + p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point + p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point + # The signs of s1 and s2 should be different. We want to find + # t such that t*s2 + (1-t)*s1 = 0. + t = (s1 / (s1 - s2))[:, None] + verts = t * p2 + (1 - t) * p1 + + return MeshDecoderOutput(verts=verts, faces=faces, vertex_channels=None) + + +@dataclass +class MLPNeRFModelOutput(BaseOutput): + density: torch.Tensor + signed_distance: torch.Tensor + channels: torch.Tensor + ts: torch.Tensor + + +class MLPNeRSTFModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + d_hidden: int = 256, + n_output: int = 12, + n_hidden_layers: int = 6, + act_fn: str = "swish", + insert_direction_at: int = 4, + ): + super().__init__() + + # Instantiate the MLP + + # Find out the dimension of encoded position and direction + dummy = torch.eye(1, 3) + d_posenc_pos = encode_position(position=dummy).shape[-1] + d_posenc_dir = encode_direction(position=dummy).shape[-1] + + mlp_widths = [d_hidden] * n_hidden_layers + input_widths = [d_posenc_pos] + mlp_widths + output_widths = mlp_widths + [n_output] + + if insert_direction_at is not None: + input_widths[insert_direction_at] += d_posenc_dir + + self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)]) + + if act_fn == "swish": + # self.activation = swish + # yiyi testing: + self.activation = lambda x: F.silu(x) + else: + raise ValueError(f"Unsupported activation function {act_fn}") + + self.sdf_activation = torch.tanh + self.density_activation = torch.nn.functional.relu + self.channel_activation = torch.sigmoid + + def map_indices_to_keys(self, output): + h_map = { + "sdf": (0, 1), + "density_coarse": (1, 2), + "density_fine": (2, 3), + "stf": (3, 6), + "nerf_coarse": (6, 9), + "nerf_fine": (9, 12), + } + + mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()} + + return mapped_output + + def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"): + h = encode_position(position) + + h_preact = h + h_directionless = None + for i, layer in enumerate(self.mlp): + if i == self.config.insert_direction_at: # 4 in the config + h_directionless = h_preact + h_direction = encode_direction(position, direction=direction) + h = torch.cat([h, h_direction], dim=-1) + + h = layer(h) + + h_preact = h + + if i < len(self.mlp) - 1: + h = self.activation(h) + + h_final = h + if h_directionless is None: + h_directionless = h_preact + + activation = self.map_indices_to_keys(h_final) + + if nerf_level == "coarse": + h_density = activation["density_coarse"] + else: + h_density = activation["density_fine"] + + if rendering_mode == "nerf": + if nerf_level == "coarse": + h_channels = activation["nerf_coarse"] + else: + h_channels = activation["nerf_fine"] + + elif rendering_mode == "stf": + h_channels = activation["stf"] + + density = self.density_activation(h_density) + signed_distance = self.sdf_activation(activation["sdf"]) + channels = self.channel_activation(h_channels) + + # yiyi notes: I think signed_distance is not used + return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts) + + +class ChannelsProj(nn.Module): + def __init__( + self, + *, + vectors: int, + channels: int, + d_latent: int, + ): + super().__init__() + self.proj = nn.Linear(d_latent, vectors * channels) + self.norm = nn.LayerNorm(channels) + self.d_latent = d_latent + self.vectors = vectors + self.channels = channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_bvd = x + w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent) + b_vc = self.proj.bias.view(1, self.vectors, self.channels) + h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd) + h = self.norm(h) + + h = h + b_vc + return h + + +class ShapEParamsProjModel(ModelMixin, ConfigMixin): + """ + project the latent representation of a 3D asset to obtain weights of a multi-layer perceptron (MLP). + + For more details, see the original paper: + """ + + @register_to_config + def __init__( + self, + *, + param_names: Tuple[str] = ( + "nerstf.mlp.0.weight", + "nerstf.mlp.1.weight", + "nerstf.mlp.2.weight", + "nerstf.mlp.3.weight", + ), + param_shapes: Tuple[Tuple[int]] = ( + (256, 93), + (256, 256), + (256, 256), + (256, 256), + ), + d_latent: int = 1024, + ): + super().__init__() + + # check inputs + if len(param_names) != len(param_shapes): + raise ValueError("Must provide same number of `param_names` as `param_shapes`") + self.projections = nn.ModuleDict({}) + for k, (vectors, channels) in zip(param_names, param_shapes): + self.projections[_sanitize_name(k)] = ChannelsProj( + vectors=vectors, + channels=channels, + d_latent=d_latent, + ) + + def forward(self, x: torch.Tensor): + out = {} + start = 0 + for k, shape in zip(self.config.param_names, self.config.param_shapes): + vectors, _ = shape + end = start + vectors + x_bvd = x[:, start:end] + out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape) + start = end + return out + + +class ShapERenderer(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + *, + param_names: Tuple[str] = ( + "nerstf.mlp.0.weight", + "nerstf.mlp.1.weight", + "nerstf.mlp.2.weight", + "nerstf.mlp.3.weight", + ), + param_shapes: Tuple[Tuple[int]] = ( + (256, 93), + (256, 256), + (256, 256), + (256, 256), + ), + d_latent: int = 1024, + d_hidden: int = 256, + n_output: int = 12, + n_hidden_layers: int = 6, + act_fn: str = "swish", + insert_direction_at: int = 4, + background: Tuple[float] = ( + 255.0, + 255.0, + 255.0, + ), + ): + super().__init__() + + self.params_proj = ShapEParamsProjModel( + param_names=param_names, + param_shapes=param_shapes, + d_latent=d_latent, + ) + self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at) + self.void = VoidNeRFModel(background=background, channel_scale=255.0) + self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0]) + self.mesh_decoder = MeshDecoder() + + @torch.no_grad() + def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False): + """ + Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written below + with some abuse of notations) + + C(r) := sum( + transmittance(t[i]) * integrate( + lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]], + ) for i in range(len(parts)) + ) + transmittance(t[-1]) * void_model(t[-1]).channels + + where + + 1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through + the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are + obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t + where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the + shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and + transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1], + math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). + + Args: + rays: [batch_size x ... x 2 x 3] origin and direction. sampler: disjoint volume integrals. n_samples: + number of ts to sample. prev_model_outputs: model outputs from the previous rendering step, including + + :return: A tuple of + - `channels` + - A importance samplers for additional fine-grained rendering + - raw model output + """ + origin, direction = rays[..., 0, :], rays[..., 1, :] + + # Integrate over [t[i], t[i + 1]] + + # 1 Intersect the rays with the current volume and sample ts to integrate along. + vrange = self.volume.intersect(origin, direction, t0_lower=None) + ts = sampler.sample(vrange.t0, vrange.t1, n_samples) + ts = ts.to(rays.dtype) + + if prev_model_out is not None: + # Append the previous ts now before fprop because previous + # rendering used a different model and we can't reuse the output. + ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values + + batch_size, *_shape, _t0_dim = vrange.t0.shape + _, *ts_shape, _ts_dim = ts.shape + + # 2. Get the points along the ray and query the model + directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3]) + positions = origin.unsqueeze(-2) + ts * directions + + directions = directions.to(self.mlp.dtype) + positions = positions.to(self.mlp.dtype) + + optional_directions = directions if render_with_direction else None + + model_out = self.mlp( + position=positions, + direction=optional_directions, + ts=ts, + nerf_level="coarse" if prev_model_out is None else "fine", + ) + + # 3. Integrate the model results + channels, weights, transmittance = integrate_samples( + vrange, model_out.ts, model_out.density, model_out.channels + ) + + # 4. Clean up results that do not intersect with the volume. + transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance)) + channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels)) + # 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). + channels = channels + transmittance * self.void(origin) + + weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights) + + return channels, weighted_sampler, model_out + + @torch.no_grad() + def decode_to_image( + self, + latents, + device, + size: int = 64, + ray_batch_size: int = 4096, + n_coarse_samples=64, + n_fine_samples=128, + ): + # project the parameters from the generated latents + projected_params = self.params_proj(latents) + + # update the mlp layers of the renderer + for name, param in self.mlp.state_dict().items(): + if f"nerstf.{name}" in projected_params.keys(): + param.copy_(projected_params[f"nerstf.{name}"].squeeze(0)) + + # create cameras object + camera = create_pan_cameras(size) + rays = camera.camera_rays + rays = rays.to(device) + n_batches = rays.shape[1] // ray_batch_size + + coarse_sampler = StratifiedRaySampler() + + images = [] + + for idx in range(n_batches): + rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size] + + # render rays with coarse, stratified samples. + _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples) + # Then, render with additional importance-weighted ray samples. + channels, _, _ = self.render_rays( + rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out + ) + + images.append(channels) + + images = torch.cat(images, dim=1) + images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0) + + return images + + @torch.no_grad() + def decode_to_mesh( + self, + latents, + device, + grid_size: int = 128, + query_batch_size: int = 4096, + texture_channels: Tuple = ("R", "G", "B"), + ): + # 1. project the parameters from the generated latents + projected_params = self.params_proj(latents) + + # 2. update the mlp layers of the renderer + for name, param in self.mlp.state_dict().items(): + if f"nerstf.{name}" in projected_params.keys(): + param.copy_(projected_params[f"nerstf.{name}"].squeeze(0)) + + # 3. decoding with STF rendering + # 3.1 query the SDF values at vertices along a regular 128**3 grid + + query_points = volume_query_points(self.volume, grid_size) + query_positions = query_points[None].repeat(1, 1, 1).to(device=device, dtype=self.mlp.dtype) + + fields = [] + + for idx in range(0, query_positions.shape[1], query_batch_size): + query_batch = query_positions[:, idx : idx + query_batch_size] + + model_out = self.mlp( + position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf" + ) + fields.append(model_out.signed_distance) + + # predicted SDF values + fields = torch.cat(fields, dim=1) + fields = fields.float() + + assert ( + len(fields.shape) == 3 and fields.shape[-1] == 1 + ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}" + + fields = fields.reshape(1, *([grid_size] * 3)) + + # create grid 128 x 128 x 128 + # - force a negative border around the SDFs to close off all the models. + full_grid = torch.zeros( + 1, + grid_size + 2, + grid_size + 2, + grid_size + 2, + device=fields.device, + dtype=fields.dtype, + ) + full_grid.fill_(-1.0) + full_grid[:, 1:-1, 1:-1, 1:-1] = fields + fields = full_grid + + # apply a differentiable implementation of Marching Cubes to construct meshs + raw_meshes = [] + mesh_mask = [] + + for field in fields: + raw_mesh = self.mesh_decoder(field, self.volume.bbox_min, self.volume.bbox_max - self.volume.bbox_min) + mesh_mask.append(True) + raw_meshes.append(raw_mesh) + + mesh_mask = torch.tensor(mesh_mask, device=fields.device) + max_vertices = max(len(m.verts) for m in raw_meshes) + + # 3.2. query the texture color head at each vertex of the resulting mesh. + texture_query_positions = torch.stack( + [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes], + dim=0, + ) + texture_query_positions = texture_query_positions.to(device=device, dtype=self.mlp.dtype) + + textures = [] + + for idx in range(0, texture_query_positions.shape[1], query_batch_size): + query_batch = texture_query_positions[:, idx : idx + query_batch_size] + + texture_model_out = self.mlp( + position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf" + ) + textures.append(texture_model_out.channels) + + # predict texture color + textures = torch.cat(textures, dim=1) + + textures = _convert_srgb_to_linear(textures) + textures = textures.float() + + # 3.3 augument the mesh with texture data + assert len(textures.shape) == 3 and textures.shape[-1] == len( + texture_channels + ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" + + for m, texture in zip(raw_meshes, textures): + texture = texture[: len(m.verts)] + m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1))) + + return raw_meshes[0] diff --git a/diffusers/src/diffusers/pipelines/stable_audio/__init__.py b/diffusers/src/diffusers/pipelines/stable_audio/__init__.py new file mode 100755 index 0000000..dfdd419 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_audio/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_stable_audio"] = ["StableAudioProjectionModel"] + _import_structure["pipeline_stable_audio"] = ["StableAudioPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .modeling_stable_audio import StableAudioProjectionModel + from .pipeline_stable_audio import StableAudioPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py b/diffusers/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py new file mode 100755 index 0000000..b8f8a70 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py @@ -0,0 +1,158 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import pi +from typing import Optional + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from ...utils import BaseOutput, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, times: torch.Tensor) -> torch.Tensor: + times = times[..., None] + freqs = times * self.weights[None] * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((times, fouriered), dim=-1) + return fouriered + + +@dataclass +class StableAudioProjectionModelOutput(BaseOutput): + """ + Args: + Class for StableAudio projection layer's outputs. + text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder. + seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the audio start hidden states. + seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the audio end hidden states. + """ + + text_hidden_states: Optional[torch.Tensor] = None + seconds_start_hidden_states: Optional[torch.Tensor] = None + seconds_end_hidden_states: Optional[torch.Tensor] = None + + +class StableAudioNumberConditioner(nn.Module): + """ + A simple linear projection model to map numbers to a latent space. + + Args: + number_embedding_dim (`int`): + Dimensionality of the number embeddings. + min_value (`int`): + The minimum value of the seconds number conditioning modules. + max_value (`int`): + The maximum value of the seconds number conditioning modules + internal_dim (`int`): + Dimensionality of the intermediate number hidden states. + """ + + def __init__( + self, + number_embedding_dim, + min_value, + max_value, + internal_dim: Optional[int] = 256, + ): + super().__init__() + self.time_positional_embedding = nn.Sequential( + StableAudioPositionalEmbedding(internal_dim), + nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim), + ) + + self.number_embedding_dim = number_embedding_dim + self.min_value = min_value + self.max_value = max_value + + def forward( + self, + floats: torch.Tensor, + ): + floats = floats.clamp(self.min_value, self.max_value) + + normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value) + + # Cast floats to same type as embedder + embedder_dtype = next(self.time_positional_embedding.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + embedding = self.time_positional_embedding(normalized_floats) + float_embeds = embedding.view(-1, 1, self.number_embedding_dim) + + return float_embeds + + +class StableAudioProjectionModel(ModelMixin, ConfigMixin): + """ + A simple linear projection model to map the conditioning values to a shared latent space. + + Args: + text_encoder_dim (`int`): + Dimensionality of the text embeddings from the text encoder (T5). + conditioning_dim (`int`): + Dimensionality of the output conditioning tensors. + min_value (`int`): + The minimum value of the seconds number conditioning modules. + max_value (`int`): + The maximum value of the seconds number conditioning modules + """ + + @register_to_config + def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value): + super().__init__() + self.text_projection = ( + nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim) + ) + self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) + self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) + + def forward( + self, + text_hidden_states: Optional[torch.Tensor] = None, + start_seconds: Optional[torch.Tensor] = None, + end_seconds: Optional[torch.Tensor] = None, + ): + text_hidden_states = ( + text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states) + ) + seconds_start_hidden_states = ( + start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds) + ) + seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds) + + return StableAudioProjectionModelOutput( + text_hidden_states=text_hidden_states, + seconds_start_hidden_states=seconds_start_hidden_states, + seconds_end_hidden_states=seconds_end_hidden_states, + ) diff --git a/diffusers/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/diffusers/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py new file mode 100755 index 0000000..4fe082d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py @@ -0,0 +1,745 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch +from transformers import ( + T5EncoderModel, + T5Tokenizer, + T5TokenizerFast, +) + +from ...models import AutoencoderOobleck, StableAudioDiTModel +from ...models.embeddings import get_1d_rotary_pos_embed +from ...schedulers import EDMDPMSolverMultistepScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from .modeling_stable_audio import StableAudioProjectionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import scipy + >>> import torch + >>> import soundfile as sf + >>> from diffusers import StableAudioPipeline + + >>> repo_id = "stabilityai/stable-audio-open-1.0" + >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # define the prompts + >>> prompt = "The sound of a hammer hitting a wooden surface." + >>> negative_prompt = "Low quality." + + >>> # set the seed for generator + >>> generator = torch.Generator("cuda").manual_seed(0) + + >>> # run the generation + >>> audio = pipe( + ... prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=200, + ... audio_end_in_s=10.0, + ... num_waveforms_per_prompt=3, + ... generator=generator, + ... ).audios + + >>> output = audio[0].T.float().cpu().numpy() + >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate) + ``` +""" + + +class StableAudioPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-audio generation using StableAudio. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderOobleck`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.T5EncoderModel`]): + Frozen text-encoder. StableAudio uses the encoder of + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant. + projection_model ([`StableAudioProjectionModel`]): + A trained model used to linearly project the hidden-states from the text encoder model and the start and + end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to + give the input to the transformer model. + tokenizer ([`~transformers.T5Tokenizer`]): + Tokenizer to tokenize text for the frozen text-encoder. + transformer ([`StableAudioDiTModel`]): + A `StableAudioDiTModel` to denoise the encoded audio latents. + scheduler ([`EDMDPMSolverMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded audio latents. + """ + + model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae" + + def __init__( + self, + vae: AutoencoderOobleck, + text_encoder: T5EncoderModel, + projection_model: StableAudioProjectionModel, + tokenizer: Union[T5Tokenizer, T5TokenizerFast], + transformer: StableAudioDiTModel, + scheduler: EDMDPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + projection_model=projection_model, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2 + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def encode_prompt( + self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # 1. Tokenize text + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because {self.text_encoder.config.model_type} can " + f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + attention_mask = attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if do_classifier_free_guidance and negative_prompt is not None: + uncond_tokens: List[str] + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # 1. Tokenize text + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if negative_attention_mask is not None: + # set the masked tokens to the null embed + negative_prompt_embeds = torch.where( + negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0 + ) + + # 3. Project prompt_embeds and negative_prompt_embeds + if do_classifier_free_guidance and negative_prompt_embeds is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the negative and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if attention_mask is not None and negative_attention_mask is None: + negative_attention_mask = torch.ones_like(attention_mask) + elif attention_mask is None and negative_attention_mask is not None: + attention_mask = torch.ones_like(negative_attention_mask) + + if attention_mask is not None: + attention_mask = torch.cat([negative_attention_mask, attention_mask]) + + prompt_embeds = self.projection_model( + text_hidden_states=prompt_embeds, + ).text_hidden_states + if attention_mask is not None: + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + + return prompt_embeds + + def encode_duration( + self, + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance, + batch_size, + ): + audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s] + audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s] + + if len(audio_start_in_s) == 1: + audio_start_in_s = audio_start_in_s * batch_size + if len(audio_end_in_s) == 1: + audio_end_in_s = audio_end_in_s * batch_size + + # Cast the inputs to floats + audio_start_in_s = [float(x) for x in audio_start_in_s] + audio_start_in_s = torch.tensor(audio_start_in_s).to(device) + + audio_end_in_s = [float(x) for x in audio_end_in_s] + audio_end_in_s = torch.tensor(audio_end_in_s).to(device) + + projection_output = self.projection_model( + start_seconds=audio_start_in_s, + end_seconds=audio_end_in_s, + ) + seconds_start_hidden_states = projection_output.seconds_start_hidden_states + seconds_end_hidden_states = projection_output.seconds_end_hidden_states + + # For classifier free guidance, we need to do two forward passes. + # Here we repeat the audio hidden states to avoid doing two forward passes + if do_classifier_free_guidance: + seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) + seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) + + return seconds_start_hidden_states, seconds_end_hidden_states + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + attention_mask=None, + negative_attention_mask=None, + initial_audio_waveforms=None, + initial_audio_sampling_rate=None, + ): + if audio_end_in_s < audio_start_in_s: + raise ValueError( + f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but " + ) + + if ( + audio_start_in_s < self.projection_model.config.min_value + or audio_start_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_start_in_s}." + ) + + if ( + audio_end_in_s < self.projection_model.config.min_value + or audio_end_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_end_in_s}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and (prompt_embeds is None): + raise ValueError( + "Provide either `prompt`, or `prompt_embeds`. Cannot leave" + "`prompt` undefined without specifying `prompt_embeds`." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]: + raise ValueError( + "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" + f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}" + ) + + if initial_audio_sampling_rate is None and initial_audio_waveforms is not None: + raise ValueError( + "`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`." + ) + + if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate: + raise ValueError( + f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`." + "Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. " + ) + + def prepare_latents( + self, + batch_size, + num_channels_vae, + sample_size, + dtype, + device, + generator, + latents=None, + initial_audio_waveforms=None, + num_waveforms_per_prompt=None, + audio_channels=None, + ): + shape = (batch_size, num_channels_vae, sample_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # encode the initial audio for use by the model + if initial_audio_waveforms is not None: + # check dimension + if initial_audio_waveforms.ndim == 2: + initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1) + elif initial_audio_waveforms.ndim != 3: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions" + ) + + audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length + audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length) + + # check num_channels + if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2: + initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1) + elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1: + initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True) + + if initial_audio_waveforms.shape[:2] != audio_shape[:2]: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`" + ) + + # crop or pad + audio_length = initial_audio_waveforms.shape[-1] + if audio_length < audio_vae_length: + logger.warning( + f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded." + ) + elif audio_length > audio_vae_length: + logger.warning( + f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped." + ) + + audio = initial_audio_waveforms.new_zeros(audio_shape) + audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length] + + encoded_audio = self.vae.encode(audio).latent_dist.sample(generator) + encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1)) + latents = encoded_audio + latents + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + audio_end_in_s: Optional[float] = None, + audio_start_in_s: Optional[float] = 0.0, + num_inference_steps: int = 100, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_waveforms_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + initial_audio_waveforms: Optional[torch.Tensor] = None, + initial_audio_sampling_rate: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + output_type: Optional[str] = "pt", + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + audio_end_in_s (`float`, *optional*, defaults to 47.55): + Audio end index in seconds. + audio_start_in_s (`float`, *optional*, defaults to 0): + Audio start index in seconds. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.0): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + initial_audio_waveforms (`torch.Tensor`, *optional*): + Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape + `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size` + corresponds to the number of prompts passed to the model. + initial_audio_sampling_rate (`int`, *optional*): + Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input + argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will + be computed from `prompt` input argument. + negative_attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion + model (LDM) output. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to latent length + downsample_ratio = self.vae.hop_length + + max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate + if audio_end_in_s is None: + audio_end_in_s = max_audio_length_in_s + + if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: + raise ValueError( + f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'." + ) + + waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) + waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate) + waveform_length = int(self.transformer.config.sample_size) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + initial_audio_waveforms, + initial_audio_sampling_rate, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + ) + + # Encode duration + seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration( + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None), + batch_size, + ) + + # Create text_audio_duration_embeds and audio_duration_embeds + text_audio_duration_embeds = torch.cat( + [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1 + ) + + audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2) + + # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and + # to concatenate it to the embeddings + if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: + negative_text_audio_duration_embeds = torch.zeros_like( + text_audio_duration_embeds, device=text_audio_duration_embeds.device + ) + text_audio_duration_embeds = torch.cat( + [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0 + ) + audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0) + + bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape + # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method + text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + text_audio_duration_embeds = text_audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, seq_len, hidden_size + ) + + audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + audio_duration_embeds = audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1] + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_vae = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_vae, + waveform_length, + text_audio_duration_embeds.dtype, + device, + generator, + latents, + initial_audio_waveforms, + num_waveforms_per_prompt, + audio_channels=self.vae.config.audio_channels, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare rotary positional embedding + rotary_embedding = get_1d_rotary_pos_embed( + self.rotary_embed_dim, + latents.shape[2] + audio_duration_embeds.shape[1], + use_real=True, + repeat_interleave_real=False, + ) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t.unsqueeze(0), + encoder_hidden_states=text_audio_duration_embeds, + global_hidden_states=audio_duration_embeds, + rotary_embedding=rotary_embedding, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 9. Post-processing + if not output_type == "latent": + audio = self.vae.decode(latents).sample + else: + return AudioPipelineOutput(audios=latents) + + audio = audio[:, :, waveform_start:waveform_end] + + if output_type == "np": + audio = audio.cpu().float().numpy() + + self.maybe_free_model_hooks() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/diffusers/src/diffusers/pipelines/stable_cascade/__init__.py b/diffusers/src/diffusers/pipelines/stable_cascade/__init__.py new file mode 100755 index 0000000..5270cb9 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_cascade/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_cascade"] = ["StableCascadeDecoderPipeline"] + _import_structure["pipeline_stable_cascade_combined"] = ["StableCascadeCombinedPipeline"] + _import_structure["pipeline_stable_cascade_prior"] = ["StableCascadePriorPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_stable_cascade import StableCascadeDecoderPipeline + from .pipeline_stable_cascade_combined import StableCascadeCombinedPipeline + from .pipeline_stable_cascade_prior import StableCascadePriorPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/diffusers/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py new file mode 100755 index 0000000..111ccc4 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py @@ -0,0 +1,528 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...models import StableCascadeUNet +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import is_torch_version, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline + + >>> prior_pipe = StableCascadePriorPipeline.from_pretrained( + ... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16 + ... ).to("cuda") + >>> gen_pipe = StableCascadeDecoderPipeline.from_pretrain( + ... "stabilityai/stable-cascade", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt) + ``` +""" + + +class StableCascadeDecoderPipeline(DiffusionPipeline): + """ + Pipeline for generating images from the Stable Cascade model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The CLIP tokenizer. + text_encoder (`CLIPTextModel`): + The CLIP text encoder. + decoder ([`StableCascadeUNet`]): + The Stable Cascade decoder unet. + vqgan ([`PaellaVQModel`]): + The VQGAN model. + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + latent_dim_scale (float, `optional`, defaults to 10.67): + Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are + height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and + width=int(24*10.67)=256 in order to match the training conditions. + """ + + unet_name = "decoder" + text_encoder_name = "text_encoder" + model_cpu_offload_seq = "text_encoder->decoder->vqgan" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds_pooled", + "negative_prompt_embeds", + "image_embeddings", + ] + + def __init__( + self, + decoder: StableCascadeUNet, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + latent_dim_scale: float = 10.67, + ) -> None: + super().__init__() + self.register_modules( + decoder=decoder, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=scheduler, + vqgan=vqgan, + ) + self.register_to_config(latent_dim_scale=latent_dim_scale) + + def prepare_latents( + self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler + ): + _, channels, height, width = image_embeddings.shape + latents_shape = ( + batch_size * num_images_per_prompt, + 4, + int(height * self.config.latent_dim_scale), + int(width * self.config.latent_dim_scale), + ) + + if latents is None: + latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + device, + batch_size, + num_images_per_prompt, + do_classifier_free_guidance, + prompt=None, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_pooled: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, + ): + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True + ) + prompt_embeds = text_encoder_output.hidden_states[-1] + if prompt_embeds_pooled is None: + prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0) + + if negative_prompt_embeds is None and do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=uncond_input.attention_mask.to(device), + output_hidden_states=True, + ) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1] + negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + seq_len = negative_prompt_embeds_pooled.shape[1] + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to( + dtype=self.text_encoder.dtype, device=device + ) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled + + def check_inputs( + self, + prompt, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + def get_timestep_ratio_conditioning(self, t, alphas_cumprod): + s = torch.tensor([0.008]) + clamp_range = [0, 1] + min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2 + var = alphas_cumprod[t] + var = var.clamp(*clamp_range) + s, min_var = s.to(var.device), min_var.to(var.device) + ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return ratio + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeddings: Union[torch.Tensor, List[torch.Tensor]], + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 10, + guidance_scale: float = 0.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_pooled: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embedding (`torch.Tensor` or `List[torch.Tensor]`): + Image Embeddings either extracted from an image or generated by a Prior Model. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 12): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely + linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image + embeddings. + """ + + # 0. Define commonly used variables + device = self._execution_device + dtype = self.decoder.dtype + self._guidance_scale = guidance_scale + if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16: + raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + if isinstance(image_embeddings, list): + image_embeddings = torch.cat(image_embeddings, dim=0) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Compute the effective number of images per prompt + # We must account for the fact that the image embeddings from the prior can be generated with num_images_per_prompt > 1 + # This results in a case where a single prompt is associated with multiple image embeddings + # Divide the number of image embeddings by the batch size to determine if this is the case. + num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size) + + # 2. Encode caption + if prompt_embeds is None and negative_prompt_embeds is None: + _, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt( + prompt=prompt, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + ) + + # The pooled embeds from the prior are pooled again before being passed to the decoder + prompt_embeds_pooled = ( + torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled]) + if self.do_classifier_free_guidance + else prompt_embeds_pooled + ) + effnet = ( + torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) + if self.do_classifier_free_guidance + else image_embeddings + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents( + batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler + ) + + if isinstance(self.scheduler, DDPMWuerstchenScheduler): + timesteps = timesteps[:-1] + else: + if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample: + self.scheduler.config.clip_sample = False # disample sample clipping + logger.warning(" set `clip_sample` to be False") + + # 6. Run denoising loop + if hasattr(self.scheduler, "betas"): + alphas = 1.0 - self.scheduler.betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + else: + alphas_cumprod = [] + + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + if not isinstance(self.scheduler, DDPMWuerstchenScheduler): + if len(alphas_cumprod) > 0: + timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod) + timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device) + else: + timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype) + else: + timestep_ratio = t.expand(latents.size(0)).to(dtype) + + # 7. Denoise latents + predicted_latents = self.decoder( + sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio, + clip_text_pooled=prompt_embeds_pooled, + effnet=effnet, + return_dict=False, + )[0] + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) + predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale) + + # 9. Renoise latents to next timestep + if not isinstance(self.scheduler, DDPMWuerstchenScheduler): + timestep_ratio = t + latents = self.scheduler.step( + model_output=predicted_latents, + timestep=timestep_ratio, + sample=latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}" + ) + + if not output_type == "latent": + # 10. Scale and decode the image latents with vq-vae + latents = self.vqgan.config.scale_factor * latents + images = self.vqgan.decode(latents).sample.clamp(0, 1) + if output_type == "np": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work + elif output_type == "pil": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work + images = self.numpy_to_pil(images) + else: + images = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return images + return ImagePipelineOutput(images) diff --git a/diffusers/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/diffusers/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py new file mode 100755 index 0000000..6724b60 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py @@ -0,0 +1,315 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict, List, Optional, Union + +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import StableCascadeUNet +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import is_torch_version, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel +from .pipeline_stable_cascade import StableCascadeDecoderPipeline +from .pipeline_stable_cascade_prior import StableCascadePriorPipeline + + +TEXT2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableCascadeCombinedPipeline + + >>> pipe = StableCascadeCombinedPipeline.from_pretrained( + ... "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.enable_model_cpu_offload() + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> images = pipe(prompt=prompt) + ``` +""" + + +class StableCascadeCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for text-to-image generation using Stable Cascade. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The decoder tokenizer to be used for text inputs. + text_encoder (`CLIPTextModel`): + The decoder text encoder to be used for text inputs. + decoder (`StableCascadeUNet`): + The decoder model to be used for decoder image generation pipeline. + scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for decoder image generation pipeline. + vqgan (`PaellaVQModel`): + The VQGAN model to be used for decoder image generation pipeline. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + prior_prior (`StableCascadeUNet`): + The prior model to be used for prior pipeline. + prior_scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for prior pipeline. + """ + + _load_connected_pipes = True + _optional_components = ["prior_feature_extractor", "prior_image_encoder"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + decoder: StableCascadeUNet, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + prior_prior: StableCascadeUNet, + prior_text_encoder: CLIPTextModel, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: DDPMWuerstchenScheduler, + prior_feature_extractor: Optional[CLIPImageProcessor] = None, + prior_image_encoder: Optional[CLIPVisionModelWithProjection] = None, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_prior=prior_prior, + prior_scheduler=prior_scheduler, + prior_feature_extractor=prior_feature_extractor, + prior_image_encoder=prior_image_encoder, + ) + self.prior_pipe = StableCascadePriorPipeline( + prior=prior_prior, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_encoder=prior_image_encoder, + feature_extractor=prior_feature_extractor, + ) + self.decoder_pipe = StableCascadeDecoderPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + r""" + Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 + Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a + GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. + Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None, + height: int = 512, + width: int = 512, + prior_num_inference_steps: int = 60, + prior_guidance_scale: float = 4.0, + num_inference_steps: int = 12, + decoder_guidance_scale: float = 0.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_pooled: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation for the prior and decoder. + images (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, *optional*): + The images to guide the image generation for the prior. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings will be generated from `prompt` input argument. + prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` + input argument. + negative_prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `prior_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked + to the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60): + The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. For more specific timestep spacing, you can pass customized + `prior_timesteps` + num_inference_steps (`int`, *optional*, defaults to 12): + The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. For more specific timestep spacing, you can pass customized + `timesteps` + decoder_guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + prior_callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: + int, callback_kwargs: Dict)`. + prior_callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the + list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeline class. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + dtype = self.decoder_pipe.decoder.dtype + if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16: + raise ValueError( + "`StableCascadeCombinedPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype." + ) + + prior_outputs = self.prior_pipe( + prompt=prompt if prompt_embeds is None else None, + images=images, + height=height, + width=width, + num_inference_steps=prior_num_inference_steps, + guidance_scale=prior_guidance_scale, + negative_prompt=negative_prompt if negative_prompt_embeds is None else None, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + output_type="pt", + return_dict=True, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + ) + image_embeddings = prior_outputs.image_embeddings + prompt_embeds = prior_outputs.get("prompt_embeds", None) + prompt_embeds_pooled = prior_outputs.get("prompt_embeds_pooled", None) + negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None) + negative_prompt_embeds_pooled = prior_outputs.get("negative_prompt_embeds_pooled", None) + + outputs = self.decoder_pipe( + image_embeddings=image_embeddings, + prompt=prompt if prompt_embeds is None else None, + num_inference_steps=num_inference_steps, + guidance_scale=decoder_guidance_scale, + negative_prompt=negative_prompt if negative_prompt_embeds is None else None, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + generator=generator, + output_type=output_type, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + return outputs diff --git a/diffusers/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/diffusers/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py new file mode 100755 index 0000000..058dbf6 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py @@ -0,0 +1,639 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import ceil +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import StableCascadeUNet +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:] + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableCascadePriorPipeline + + >>> prior_pipe = StableCascadePriorPipeline.from_pretrained( + ... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16 + ... ).to("cuda") + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + ``` +""" + + +@dataclass +class StableCascadePriorPipelineOutput(BaseOutput): + """ + Output class for WuerstchenPriorPipeline. + + Args: + image_embeddings (`torch.Tensor` or `np.ndarray`) + Prior image embeddings for text prompt + prompt_embeds (`torch.Tensor`): + Text embeddings for the prompt. + negative_prompt_embeds (`torch.Tensor`): + Text embeddings for the negative prompt. + """ + + image_embeddings: Union[torch.Tensor, np.ndarray] + prompt_embeds: Union[torch.Tensor, np.ndarray] + prompt_embeds_pooled: Union[torch.Tensor, np.ndarray] + negative_prompt_embeds: Union[torch.Tensor, np.ndarray] + negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray] + + +class StableCascadePriorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Stable Cascade. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`StableCascadeUNet`]): + The Stable Cascade prior to approximate the image embedding from the text and/or image embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + resolution_multiple ('float', *optional*, defaults to 42.67): + Default resolution for multiple images generated. + """ + + unet_name = "prior" + text_encoder_name = "text_encoder" + model_cpu_offload_seq = "image_encoder->text_encoder->prior" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + prior: StableCascadeUNet, + scheduler: DDPMWuerstchenScheduler, + resolution_multiple: float = 42.67, + feature_extractor: Optional[CLIPImageProcessor] = None, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + ) -> None: + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + prior=prior, + scheduler=scheduler, + ) + self.register_to_config(resolution_multiple=resolution_multiple) + + def prepare_latents( + self, batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, scheduler + ): + latent_shape = ( + num_images_per_prompt * batch_size, + self.prior.config.in_channels, + ceil(height / self.config.resolution_multiple), + ceil(width / self.config.resolution_multiple), + ) + + if latents is None: + latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != latent_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latent_shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + device, + batch_size, + num_images_per_prompt, + do_classifier_free_guidance, + prompt=None, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_pooled: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, + ): + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True + ) + prompt_embeds = text_encoder_output.hidden_states[-1] + if prompt_embeds_pooled is None: + prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0) + + if negative_prompt_embeds is None and do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=uncond_input.attention_mask.to(device), + output_hidden_states=True, + ) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1] + negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + seq_len = negative_prompt_embeds_pooled.shape[1] + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to( + dtype=self.text_encoder.dtype, device=device + ) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled + + def encode_image(self, images, device, dtype, batch_size, num_images_per_prompt): + image_embeds = [] + for image in images: + image = self.feature_extractor(image, return_tensors="pt").pixel_values + image = image.to(device=device, dtype=dtype) + image_embed = self.image_encoder(image).image_embeds.unsqueeze(1) + image_embeds.append(image_embed) + image_embeds = torch.cat(image_embeds, dim=1) + + image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) + negative_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, negative_image_embeds + + def check_inputs( + self, + prompt, + images=None, + image_embeds=None, + negative_prompt=None, + prompt_embeds=None, + prompt_embeds_pooled=None, + negative_prompt_embeds=None, + negative_prompt_embeds_pooled=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and prompt_embeds_pooled is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`" + ) + + if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`" + ) + + if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None: + if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape: + raise ValueError( + "`prompt_embeds_pooled` and `negative_prompt_embeds_pooled` must have the same shape when passed" + f"directly, but got: `prompt_embeds_pooled` {prompt_embeds_pooled.shape} !=" + f"`negative_prompt_embeds_pooled` {negative_prompt_embeds_pooled.shape}." + ) + + if image_embeds is not None and images is not None: + raise ValueError( + f"Cannot forward both `images`: {images} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + + if images: + for i, image in enumerate(images): + if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise TypeError( + f"'images' must contain images of type 'torch.Tensor' or 'PIL.Image.Image, but got" + f"{type(image)} for image number {i}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + def get_timestep_ratio_conditioning(self, t, alphas_cumprod): + s = torch.tensor([0.008]) + clamp_range = [0, 1] + min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2 + var = alphas_cumprod[t] + var = var.clamp(*clamp_range) + s, min_var = s.to(var.device), min_var.to(var.device) + ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return ratio + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 20, + timesteps: List[float] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_pooled: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pt", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 60): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 8.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely + linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` + input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs, *e.g.* prompt weighting. If + not provided, image embeddings will be generated from `image` input argument if existing. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`StableCascadePriorPipelineOutput`] or `tuple` [`StableCascadePriorPipelineOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image + embeddings. + """ + + # 0. Define commonly used variables + device = self._execution_device + dtype = next(self.prior.parameters()).dtype + self._guidance_scale = guidance_scale + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + images=images, + image_embeds=image_embeds, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. Encode caption + images + ( + prompt_embeds, + prompt_embeds_pooled, + negative_prompt_embeds, + negative_prompt_embeds_pooled, + ) = self.encode_prompt( + prompt=prompt, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + ) + + if images is not None: + image_embeds_pooled, uncond_image_embeds_pooled = self.encode_image( + images=images, + device=device, + dtype=dtype, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + elif image_embeds is not None: + image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) + uncond_image_embeds_pooled = torch.zeros_like(image_embeds_pooled) + else: + image_embeds_pooled = torch.zeros( + batch_size * num_images_per_prompt, + 1, + self.prior.config.clip_image_in_channels, + device=device, + dtype=dtype, + ) + uncond_image_embeds_pooled = torch.zeros( + batch_size * num_images_per_prompt, + 1, + self.prior.config.clip_image_in_channels, + device=device, + dtype=dtype, + ) + + if self.do_classifier_free_guidance: + image_embeds = torch.cat([image_embeds_pooled, uncond_image_embeds_pooled], dim=0) + else: + image_embeds = image_embeds_pooled + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_encoder_hidden_states = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds + ) + text_encoder_pooled = ( + torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled]) + if negative_prompt_embeds is not None + else prompt_embeds_pooled + ) + + # 4. Prepare and set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents( + batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, self.scheduler + ) + + if isinstance(self.scheduler, DDPMWuerstchenScheduler): + timesteps = timesteps[:-1] + else: + if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample: + self.scheduler.config.clip_sample = False # disample sample clipping + logger.warning(" set `clip_sample` to be False") + # 6. Run denoising loop + if hasattr(self.scheduler, "betas"): + alphas = 1.0 - self.scheduler.betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + else: + alphas_cumprod = [] + + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + if not isinstance(self.scheduler, DDPMWuerstchenScheduler): + if len(alphas_cumprod) > 0: + timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod) + timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device) + else: + timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype) + else: + timestep_ratio = t.expand(latents.size(0)).to(dtype) + # 7. Denoise image embeddings + predicted_image_embedding = self.prior( + sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio, + clip_text_pooled=text_encoder_pooled, + clip_text=text_encoder_hidden_states, + clip_img=image_embeds, + return_dict=False, + )[0] + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + predicted_image_embedding = torch.lerp( + predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale + ) + + # 9. Renoise latents to next timestep + if not isinstance(self.scheduler, DDPMWuerstchenScheduler): + timestep_ratio = t + latents = self.scheduler.step( + model_output=predicted_image_embedding, timestep=timestep_ratio, sample=latents, generator=generator + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # Offload all models + self.maybe_free_model_hooks() + + if output_type == "np": + latents = latents.cpu().float().numpy() # float() as bfloat16-> numpy doesnt work + prompt_embeds = prompt_embeds.cpu().float().numpy() # float() as bfloat16-> numpy doesnt work + negative_prompt_embeds = ( + negative_prompt_embeds.cpu().float().numpy() if negative_prompt_embeds is not None else None + ) # float() as bfloat16-> numpy doesnt work + + if not return_dict: + return ( + latents, + prompt_embeds, + prompt_embeds_pooled, + negative_prompt_embeds, + negative_prompt_embeds_pooled, + ) + + return StableCascadePriorPipelineOutput( + image_embeddings=latents, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + ) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/README.md b/diffusers/src/diffusers/pipelines/stable_diffusion/README.md new file mode 100755 index 0000000..5b229fd --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/README.md @@ -0,0 +1,176 @@ +# Stable Diffusion + +## Overview + +Stable Diffusion was proposed in [Stable Diffusion Announcement](https://stability.ai/blog/stable-diffusion-announcement) by Patrick Esser and Robin Rombach and the Stability AI team. + +The summary of the model is the following: + +*Stable Diffusion is a text-to-image model that will empower billions of people to create stunning art within seconds. It is a breakthrough in speed and quality meaning that it can run on consumer GPUs. You can see some of the amazing output that has been created by this model without pre or post-processing on this page. The model itself builds upon the work of the team at CompVis and Runway in their widely used latent diffusion model combined with insights from the conditional diffusion models by our lead generative AI developer Katherine Crowson, Dall-E 2 by Open AI, Imagen by Google Brain and many others. We are delighted that AI media generation is a cooperative field and hope it can continue this way to bring the gift of creativity to all.* + +## Tips: + +- Stable Diffusion has the same architecture as [Latent Diffusion](https://arxiv.org/abs/2112.10752) but uses a frozen CLIP Text Encoder instead of training the text encoder jointly with the diffusion model. +- An in-detail explanation of the Stable Diffusion model can be found under [Stable Diffusion with 🧨 Diffusers](https://huggingface.co/blog/stable_diffusion). +- If you don't want to rely on the Hugging Face Hub and having to pass a authentication token, you can +download the weights with `git lfs install; git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5` and instead pass the local path to the cloned folder to `from_pretrained` as shown below. +- Stable Diffusion can work with a variety of different samplers as is shown below. + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|---|---|:---:| +| [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [pipeline_stable_diffusion_img2img](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [pipeline_stable_diffusion_inpaint](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) + +## Examples: + +### Using Stable Diffusion without being logged into the Hub. + +If you want to download the model weights using a single Python line, you need to be logged in via `huggingface-cli login`. + +```python +from diffusers import DiffusionPipeline + +pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") +``` + +This however can make it difficult to build applications on top of `diffusers` as you will always have to pass the token around. A potential way to solve this issue is by downloading the weights to a local path `"./stable-diffusion-v1-5"`: + +``` +git lfs install +git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5 +``` + +and simply passing the local path to `from_pretrained`: + +```python +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5") +``` + +### Text-to-Image with default PLMS scheduler + +```python +# make sure you're logged in with `huggingface-cli login` +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5") +pipe = pipe.to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +image = pipe(prompt).images[0] + +image.save("astronaut_rides_horse.png") +``` + +### Text-to-Image with DDIM scheduler + +```python +# make sure you're logged in with `huggingface-cli login` +from diffusers import StableDiffusionPipeline, DDIMScheduler + +scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") + +pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", + scheduler=scheduler, +).to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +image = pipe(prompt).images[0] + +image.save("astronaut_rides_horse.png") +``` + +### Text-to-Image with K-LMS scheduler + +```python +# make sure you're logged in with `huggingface-cli login` +from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler + +lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") + +pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", + scheduler=lms, +).to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +image = pipe(prompt).images[0] + +image.save("astronaut_rides_horse.png") +``` + +### CycleDiffusion using Stable Diffusion and DDIM scheduler + +```python +import requests +import torch +from PIL import Image +from io import BytesIO + +from diffusers import CycleDiffusionPipeline, DDIMScheduler + + +# load the scheduler. CycleDiffusion only supports stochastic schedulers. + +# load the pipeline +# make sure you're logged in with `huggingface-cli login` +model_id_or_path = "CompVis/stable-diffusion-v1-4" +scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") +pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") + +# let's download an initial image +url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png" +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((512, 512)) +init_image.save("horse.png") + +# let's specify a prompt +source_prompt = "An astronaut riding a horse" +prompt = "An astronaut riding an elephant" + +# call the pipeline +image = pipe( + prompt=prompt, + source_prompt=source_prompt, + image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.8, + guidance_scale=2, + source_guidance_scale=1, +).images[0] + +image.save("horse_to_elephant.png") + +# let's try another example +# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion +url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png" +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((512, 512)) +init_image.save("black.png") + +source_prompt = "A black colored car" +prompt = "A blue colored car" + +# call the pipeline +torch.manual_seed(0) +image = pipe( + prompt=prompt, + source_prompt=source_prompt, + image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.85, + guidance_scale=3, + source_guidance_scale=1, +).images[0] + +image.save("black_to_blue.png") +``` diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/stable_diffusion/__init__.py new file mode 100755 index 0000000..8ce08ae --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -0,0 +1,202 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_k_diffusion_available, + is_k_diffusion_version, + is_onnx_available, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["StableDiffusionPipelineOutput"]} + +if is_transformers_available() and is_flax_available(): + _import_structure["pipeline_output"].extend(["FlaxStableDiffusionPipelineOutput"]) +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["clip_image_project_model"] = ["CLIPImageProjection"] + _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"] + _import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"] + _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] + _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"] + _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"] + _import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"] + _import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"] + _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"] + _import_structure["pipeline_stable_diffusion_instruct_pix2pix"] = ["StableDiffusionInstructPix2PixPipeline"] + _import_structure["pipeline_stable_diffusion_latent_upscale"] = ["StableDiffusionLatentUpscalePipeline"] + _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"] + _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"] + _import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"] + _import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"] + _import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"] + _import_structure["safety_checker"] = ["StableDiffusionSafetyChecker"] + _import_structure["stable_unclip_image_normalizer"] = ["StableUnCLIPImageNormalizer"] +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + StableDiffusionImageVariationPipeline, + ) + + _dummy_objects.update({"StableDiffusionImageVariationPipeline": StableDiffusionImageVariationPipeline}) +else: + _import_structure["pipeline_stable_diffusion_image_variation"] = ["StableDiffusionImageVariationPipeline"] +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + StableDiffusionDepth2ImgPipeline, + ) + + _dummy_objects.update( + { + "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, + } + ) +else: + _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"] + +try: + if not (is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) +else: + _import_structure["pipeline_onnx_stable_diffusion"] = [ + "OnnxStableDiffusionPipeline", + "StableDiffusionOnnxPipeline", + ] + _import_structure["pipeline_onnx_stable_diffusion_img2img"] = ["OnnxStableDiffusionImg2ImgPipeline"] + _import_structure["pipeline_onnx_stable_diffusion_inpaint"] = ["OnnxStableDiffusionInpaintPipeline"] + _import_structure["pipeline_onnx_stable_diffusion_inpaint_legacy"] = ["OnnxStableDiffusionInpaintPipelineLegacy"] + _import_structure["pipeline_onnx_stable_diffusion_upscale"] = ["OnnxStableDiffusionUpscalePipeline"] + +if is_transformers_available() and is_flax_available(): + from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState + + _additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState}) + _import_structure["pipeline_flax_stable_diffusion"] = ["FlaxStableDiffusionPipeline"] + _import_structure["pipeline_flax_stable_diffusion_img2img"] = ["FlaxStableDiffusionImg2ImgPipeline"] + _import_structure["pipeline_flax_stable_diffusion_inpaint"] = ["FlaxStableDiffusionInpaintPipeline"] + _import_structure["safety_checker_flax"] = ["FlaxStableDiffusionSafetyChecker"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .clip_image_project_model import CLIPImageProjection + from .pipeline_stable_diffusion import ( + StableDiffusionPipeline, + StableDiffusionPipelineOutput, + ) + from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline + from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline + from .pipeline_stable_diffusion_instruct_pix2pix import ( + StableDiffusionInstructPix2PixPipeline, + ) + from .pipeline_stable_diffusion_latent_upscale import ( + StableDiffusionLatentUpscalePipeline, + ) + from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline + from .pipeline_stable_unclip import StableUnCLIPPipeline + from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline + from .safety_checker import StableDiffusionSafetyChecker + from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + StableDiffusionImageVariationPipeline, + ) + else: + from .pipeline_stable_diffusion_image_variation import ( + StableDiffusionImageVariationPipeline, + ) + + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import StableDiffusionDepth2ImgPipeline + else: + from .pipeline_stable_diffusion_depth2img import ( + StableDiffusionDepth2ImgPipeline, + ) + + try: + if not (is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_onnx_objects import * + else: + from .pipeline_onnx_stable_diffusion import ( + OnnxStableDiffusionPipeline, + StableDiffusionOnnxPipeline, + ) + from .pipeline_onnx_stable_diffusion_img2img import ( + OnnxStableDiffusionImg2ImgPipeline, + ) + from .pipeline_onnx_stable_diffusion_inpaint import ( + OnnxStableDiffusionInpaintPipeline, + ) + from .pipeline_onnx_stable_diffusion_upscale import ( + OnnxStableDiffusionUpscalePipeline, + ) + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_objects import * + else: + from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline + from .pipeline_flax_stable_diffusion_img2img import ( + FlaxStableDiffusionImg2ImgPipeline, + ) + from .pipeline_flax_stable_diffusion_inpaint import ( + FlaxStableDiffusionInpaintPipeline, + ) + from .pipeline_output import FlaxStableDiffusionPipelineOutput + from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py b/diffusers/src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py new file mode 100755 index 0000000..71f9d97 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py @@ -0,0 +1,29 @@ +# Copyright 2024 The GLIGEN Authors and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class CLIPImageProjection(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, hidden_size: int = 768): + super().__init__() + self.hidden_size = hidden_size + self.project = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + def forward(self, x): + return self.project(x) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/diffusers/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py new file mode 100755 index 0000000..53dc98a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -0,0 +1,1869 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the Stable Diffusion checkpoints.""" + +import re +from contextlib import nullcontext +from io import BytesIO +from typing import Dict, Optional, Union + +import requests +import torch +import yaml +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from ...models import ( + AutoencoderKL, + ControlNetModel, + PriorTransformer, + UNet2DConditionModel, +) +from ...schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from ...utils import is_accelerate_available, logging +from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from ..paint_by_example import PaintByExampleImageEncoder +from ..pipeline_utils import DiffusionPipeline +from .safety_checker import StableDiffusionSafetyChecker +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if controlnet: + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] + else: + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] + else: + unet_params = original_config["model"]["params"]["network_config"]["params"] + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params["transformer_depth"] is not None: + transformer_layers_per_block = ( + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params["context_dim"] is not None: + context_dim = ( + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] + ) + + if "num_classes" in unet_params: + if unet_params["num_classes"] == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params["disable_self_attentions"] + + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] + + if controlnet: + config["conditioning_channels"] = unet_params["hint_channels"] + else: + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + } + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config["model"]["params"]["timesteps"], + beta_start=original_config["model"]["params"]["linear_start"], + beta_end=original_config["model"]["params"]["linear_end"], + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config["model"]["params"]["cond_stage_config"]["params"] + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + # Relevant to StableDiffusionUpscalePipeline + if "num_class_embeds" in config: + if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): + new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in sorted(output_block_list.items())} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): + if text_encoder is None: + config_name = "openai/clip-vit-large-patch14" + try: + config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'." + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModel(config) + else: + text_model = text_encoder + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): + text_model_dict.pop("text_model.embeddings.position_ids", None) + + text_model.load_state_dict(text_model_dict) + + return text_model + + +textenc_conversion_lst = [ + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("ln_final.weight", "text_model.final_layer_norm.weight"), + ("ln_final.bias", "text_model.final_layer_norm.bias"), + ("text_projection", "text_projection.weight"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False): + config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix="cond_stage_model.model.", + has_projection=False, + local_files_only=False, + **config_kwargs, +): + # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + # text_model = CLIPTextModelWithProjection.from_pretrained( + # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 + # ) + try: + config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'." + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + keys_to_ignore = [] + if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: + # make sure to remove all keys > 22 + keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] + keys_to_ignore += ["cond_stage_model.model.text_projection"] + + text_model_dict = {} + + if prefix + "text_projection" in checkpoint: + d_model = int(checkpoint[prefix + "text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if key in keys_to_ignore: + continue + if key[len(prefix) :] in textenc_conversion_map: + if key.endswith("text_projection"): + value = checkpoint[key].T.contiguous() + else: + value = checkpoint[key] + + text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value + + if key.startswith(prefix + "transformer."): + new_key = key[len(prefix + "transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): + text_model_dict.pop("text_model.embeddings.position_ids", None) + + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config, local_files_only=False): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. + """ + + image_embedder_config = original_config["model"]["params"]["embedder_config"] + + sd_clip_image_embedder_class = image_embedder_config["target"] + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only + ) + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config["model"]["params"]["noise_aug_config"] + noise_aug_class = noise_aug_config["target"] + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + + if use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + controlnet = ControlNetModel(**ctrlnet_config) + + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, + ctrlnet_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, + ) + + if is_accelerate_available(): + for param_name, param in converted_ctrl_checkpoint.items(): + set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) + else: + controlnet.load_state_dict(converted_ctrl_checkpoint) + + return controlnet + + +def download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]], + original_config_file: str = None, + image_size: Optional[int] = None, + prediction_type: str = None, + model_type: str = None, + extract_ema: bool = False, + scheduler_type: str = "pndm", + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + stable_unclip: Optional[str] = None, + stable_unclip_prior: Optional[str] = None, + clip_stats_path: Optional[str] = None, + controlnet: Optional[bool] = None, + adapter: Optional[bool] = None, + load_safety_checker: bool = True, + safety_checker: Optional[StableDiffusionSafetyChecker] = None, + feature_extractor: Optional[AutoFeatureExtractor] = None, + pipeline_class: DiffusionPipeline = None, + local_files_only=False, + vae_path=None, + vae=None, + text_encoder=None, + text_encoder_2=None, + tokenizer=None, + tokenizer_2=None, + config_files=None, +) -> DiffusionPipeline: + """ + Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` + config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + inferred by looking for a key that only exists in SD2.0 models. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 + Base. Use 768 for Stable Diffusion v2. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable + Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. + num_in_channels (`int`, *optional*, defaults to None): + The number of input channels. If `None`, it will be automatically inferred. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + model_type (`str`, *optional*, defaults to `None`): + The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", + "FrozenCLIPEmbedder", "PaintByExample"]`. + is_img2img (`bool`, *optional*, defaults to `False`): + Whether the model should be loaded as an img2img pipeline. + extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to + `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. This is necessary when running stable + diffusion 2.1. + device (`str`, *optional*, defaults to `None`): + The device to use. Pass `None` to determine automatically. + from_safetensors (`str`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. + load_safety_checker (`bool`, *optional*, defaults to `True`): + Whether to load the safety checker or not. Defaults to `True`. + safety_checker (`StableDiffusionSafetyChecker`, *optional*, defaults to `None`): + Safety checker to use. If this parameter is `None`, the function will load a new instance of + [StableDiffusionSafetyChecker] by itself, if needed. + feature_extractor (`AutoFeatureExtractor`, *optional*, defaults to `None`): + Feature extractor to use. If this parameter is `None`, the function will load a new instance of + [AutoFeatureExtractor] by itself, if needed. + pipeline_class (`str`, *optional*, defaults to `None`): + The pipeline class to use. Pass `None` to determine automatically. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + vae (`AutoencoderKL`, *optional*, defaults to `None`): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If + this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. + text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): + An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) + to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) + variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. + tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): + An instance of + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if + needed. + config_files (`Dict[str, str]`, *optional*, defaults to `None`): + A dictionary mapping from config file names to their contents. If this parameter is `None`, the function + will load the config files by itself, if needed. Valid keys are: + - `v1`: Config file for Stable Diffusion v1 + - `v2`: Config file for Stable Diffusion v2 + - `xl`: Config file for Stable Diffusion XL + - `xl_refiner`: Config file for Stable Diffusion XL Refiner + return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + """ + + # import pipelines here to avoid circular import error when using from_single_file method + from diffusers import ( + LDMTextToImagePipeline, + PaintByExamplePipeline, + StableDiffusionControlNetPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + StableDiffusionUpscalePipeline, + StableDiffusionXLControlNetInpaintPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + + if prediction_type == "v-prediction": + prediction_type = "v_prediction" + + if isinstance(checkpoint_path_or_dict, str): + if from_safetensors: + from safetensors.torch import load_file as safe_load + + checkpoint = safe_load(checkpoint_path_or_dict, device="cpu") + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) + else: + checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) + elif isinstance(checkpoint_path_or_dict, dict): + checkpoint = checkpoint_path_or_dict + + # Sometimes models don't have the global_step item + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + logger.debug("global_step key not found in model") + global_step = None + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + if original_config_file is None: + key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" + key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" + is_upscale = pipeline_class == StableDiffusionUpscalePipeline + + config_url = None + + # model_type = "v1" + if config_files is not None and "v1" in config_files: + original_config_file = config_files["v1"] + else: + config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + + if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: + # model_type = "v2" + if config_files is not None and "v2" in config_files: + original_config_file = config_files["v2"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + elif key_name_sd_xl_base in checkpoint: + # only base xl has two text embedders + if config_files is not None and "xl" in config_files: + original_config_file = config_files["xl"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" + elif key_name_sd_xl_refiner in checkpoint: + # only refiner xl has embedder and one text embedders + if config_files is not None and "xl_refiner" in config_files: + original_config_file = config_files["xl_refiner"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" + + if is_upscale: + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml" + + if config_url is not None: + original_config_file = BytesIO(requests.get(config_url).content) + else: + with open(original_config_file, "r") as f: + original_config_file = f.read() + else: + with open(original_config_file, "r") as f: + original_config_file = f.read() + + original_config = yaml.safe_load(original_config_file) + + # Convert the text model. + if ( + model_type is None + and "cond_stage_config" in original_config["model"]["params"] + and original_config["model"]["params"]["cond_stage_config"] is not None + ): + model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] + logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + elif model_type is None and original_config["model"]["params"]["network_config"] is not None: + if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048: + model_type = "SDXL" + else: + model_type = "SDXL-Refiner" + if image_size is None: + image_size = 1024 + + if pipeline_class is None: + # Check if we have a SDXL or SD model and initialize default pipeline + if model_type not in ["SDXL", "SDXL-Refiner"]: + pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline + else: + pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline + + if num_in_channels is None and pipeline_class in [ + StableDiffusionInpaintPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLControlNetInpaintPipeline, + ]: + num_in_channels = 9 + if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline: + num_in_channels = 7 + elif num_in_channels is None: + num_in_channels = 4 + + if "unet_config" in original_config["model"]["params"]: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + elif "network_config" in original_config["model"]["params"]: + original_config["model"]["params"]["network_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + if controlnet is None and "control_stage_config" in original_config["model"]["params"]: + path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" + controlnet = convert_controlnet_checkpoint( + checkpoint, original_config, path, image_size, upcast_attention, extract_ema + ) + + if "timesteps" in original_config["model"]["params"]: + num_train_timesteps = original_config["model"]["params"]["timesteps"] + else: + num_train_timesteps = 1000 + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_dict = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": num_train_timesteps, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", + } + scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler" + else: + if "linear_start" in original_config["model"]["params"]: + beta_start = original_config["model"]["params"]["linear_start"] + else: + beta_start = 0.02 + + if "linear_end" in original_config["model"]["params"]: + beta_end = original_config["model"]["params"]["linear_end"] + else: + beta_end = 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + if pipeline_class == StableDiffusionUpscalePipeline: + image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"] + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = upcast_attention + + path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=path, extract_ema=extract_ema + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + unet = UNet2DConditionModel(**unet_config) + + if is_accelerate_available(): + if model_type not in ["SDXL", "SDXL-Refiner"]: # SBM Delay this. + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + else: + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model. + if vae_path is None and vae is None: + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config["model"] + and "scale_factor" in original_config["model"]["params"] + ): + vae_scaling_factor = original_config["model"]["params"]["scale_factor"] + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(converted_vae_checkpoint) + elif vae is None: + vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + if text_encoder is None: + text_model = convert_open_clip_checkpoint( + checkpoint, config_name, local_files_only=local_files_only, **config_kwargs + ) + else: + text_model = text_encoder + + try: + tokenizer = CLIPTokenizer.from_pretrained( + "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'." + ) + + if stable_unclip is None: + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + controlnet=controlnet, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + if hasattr(pipe, "requires_safety_checker"): + pipe.requires_safety_checker = False + + elif pipeline_class == StableDiffusionUpscalePipeline: + scheduler = DDIMScheduler.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler" + ) + low_res_scheduler = DDPMScheduler.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler" + ) + + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + low_res_scheduler=low_res_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + else: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + if hasattr(pipe, "requires_safety_checker"): + pipe.requires_safety_checker = False + + else: + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( + original_config, clip_stats_path=clip_stats_path, device=device + ) + + if stable_unclip == "img2img": + feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) + + pipe = StableUnCLIPImg2ImgPipeline( + # image encoding components + feature_extractor=feature_extractor, + image_encoder=image_encoder, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + elif stable_unclip == "txt2img": + if stable_unclip_prior is None or stable_unclip_prior == "karlo": + karlo_model = "kakaobrain/karlo-v1-alpha" + prior = PriorTransformer.from_pretrained( + karlo_model, subfolder="prior", local_files_only=local_files_only + ) + + try: + prior_tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + prior_text_model = CLIPTextModelWithProjection.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + + prior_scheduler = UnCLIPScheduler.from_pretrained( + karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only + ) + prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) + else: + raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") + + pipe = StableUnCLIPPipeline( + # prior components + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_model, + prior=prior, + prior_scheduler=prior_scheduler, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + else: + raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") + elif model_type == "PaintByExample": + vision_model = convert_paint_by_example_checkpoint(checkpoint) + try: + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + try: + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'." + ) + pipe = PaintByExamplePipeline( + vae=vae, + image_encoder=vision_model, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=feature_extractor, + ) + elif model_type == "FrozenCLIPEmbedder": + text_model = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only, text_encoder=text_encoder + ) + try: + tokenizer = ( + CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) + if tokenizer is None + else tokenizer + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + if load_safety_checker: + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + elif model_type in ["SDXL", "SDXL-Refiner"]: + is_refiner = model_type == "SDXL-Refiner" + + if (is_refiner is False) and (tokenizer is None): + try: + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + if (is_refiner is False) and (text_encoder is None): + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + + if tokenizer_2 is None: + try: + tokenizer_2 = CLIPTokenizer.from_pretrained( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'." + ) + + if text_encoder_2 is None: + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model." + + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix=prefix, + has_projection=True, + local_files_only=local_files_only, + **config_kwargs, + ) + + if is_accelerate_available(): # SBM Now move model to cpu. + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + elif adapter: + pipe = pipeline_class( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + adapter=adapter, + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + + else: + pipeline_kwargs = { + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "unet": unet, + "scheduler": scheduler, + } + + if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or ( + pipeline_class == StableDiffusionXLInpaintPipeline + ): + pipeline_kwargs.update({"requires_aesthetics_score": is_refiner}) + + if is_refiner: + pipeline_kwargs.update({"force_zeros_for_empty_prompt": False}) + + pipe = pipeline_class(**pipeline_kwargs) + else: + text_config = create_ldm_bert_config(original_config) + text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) + tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=local_files_only) + pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + return pipe + + +def download_controlnet_from_original_ckpt( + checkpoint_path: str, + original_config_file: str, + image_size: int = 512, + extract_ema: bool = False, + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + use_linear_projection: Optional[bool] = None, + cross_attention_dim: Optional[bool] = None, +) -> DiffusionPipeline: + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + with open(original_config_file, "r") as f: + original_config_file = f.read() + original_config = yaml.safe_load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if "control_stage_config" not in original_config["model"]["params"]: + raise ValueError("`control_stage_config` not present in original config") + + controlnet = convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + ) + + return controlnet diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py new file mode 100755 index 0000000..5d6ffd4 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -0,0 +1,473 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from packaging import version +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import deprecate, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from .pipeline_output import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + + >>> from diffusers import FlaxStableDiffusionPipeline + + >>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jax.numpy.bfloat16 + ... ) + + >>> prompt = "a photo of an astronaut riding a horse on mars" + + >>> prng_seed = jax.random.PRNGKey(0) + >>> num_inference_steps = 50 + + >>> num_samples = jax.device_count() + >>> prompt = num_samples * [prompt] + >>> prompt_ids = pipeline.prepare_inputs(prompt) + # shard inputs and rng + + >>> params = replicate(params) + >>> prng_seed = jax.random.split(prng_seed, jax.device_count()) + >>> prompt_ids = shard(prompt_ids) + + >>> images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images + >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) + ``` +""" + + +class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): + r""" + Flax-based pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.FlaxCLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`FlaxUNet2DConditionModel`]): + A `FlaxUNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_inputs(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.Array, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + latents: Optional[jnp.ndarray] = None, + neg_prompt_ids: Optional[jnp.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + # Ensure model output will be `float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.Array, + num_inference_steps: int = 50, + height: Optional[int] = None, + width: Optional[int] = None, + guidance_scale: Union[float, jnp.ndarray] = 7.5, + latents: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, + return_dict: bool = True, + jit: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + latents (`jnp.ndarray`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + array is generated by sampling using the supplied random `generator`. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. + + + + This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a + future release. + + + + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated images + and the second element is a list of `bool`s indicating whether the corresponding generated image + contains "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + else: + images = self._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images).copy() + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i, 0] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 4, 5, 6), +) +def _p_generate( + pipe, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, +): + return pipe._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py new file mode 100755 index 0000000..7792bc0 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -0,0 +1,532 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from .pipeline_output import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + >>> from diffusers import FlaxStableDiffusionImg2ImgPipeline + + + >>> def create_key(seed=0): + ... return jax.random.PRNGKey(seed) + + + >>> rng = create_key(0) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> init_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_img = init_img.resize((768, 512)) + + >>> prompts = "A fantasy landscape, trending on artstation" + + >>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", + ... revision="flax", + ... dtype=jnp.bfloat16, + ... ) + + >>> num_samples = jax.device_count() + >>> rng = jax.random.split(rng, jax.device_count()) + >>> prompt_ids, processed_image = pipeline.prepare_inputs( + ... prompt=[prompts] * num_samples, image=[init_img] * num_samples + ... ) + >>> p_params = replicate(params) + >>> prompt_ids = shard(prompt_ids) + >>> processed_image = shard(processed_image) + + >>> output = pipeline( + ... prompt_ids=prompt_ids, + ... image=processed_image, + ... params=p_params, + ... prng_seed=rng, + ... strength=0.75, + ... num_inference_steps=50, + ... jit=True, + ... height=512, + ... width=768, + ... ).images + + >>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) + ``` +""" + + +class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): + r""" + Flax-based pipeline for text-guided image-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.FlaxCLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`FlaxUNet2DConditionModel`]): + A `FlaxUNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids, processed_images + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def get_timestep_start(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + + return t_start + + def _generate( + self, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, + params: Union[Dict, FrozenDict], + prng_seed: jax.Array, + start_timestep: int, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + noise: Optional[jnp.ndarray] = None, + neg_prompt_ids: Optional[jnp.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if noise is None: + noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if noise.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}") + + # Create init_latents + init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist + init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) + init_latents = self.vae.config.scaling_factor * init_latents + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size) + + latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(start_timestep, num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, + params: Union[Dict, FrozenDict], + prng_seed: jax.Array, + strength: float = 0.8, + num_inference_steps: int = 50, + height: Optional[int] = None, + width: Optional[int] = None, + guidance_scale: Union[float, jnp.ndarray] = 7.5, + noise: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, + return_dict: bool = True, + jit: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt_ids (`jnp.ndarray`): + The prompt or prompts to guide image generation. + image (`jnp.ndarray`): + Array representing an image batch to be used as the starting point. + params (`Dict` or `FrozenDict`): + Dictionary containing the model parameters/weights. + prng_seed (`jax.Array` or `jax.Array`): + Array containing random number generator key. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + noise (`jnp.ndarray`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. The array is generated by + sampling using the supplied random `generator`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. + + + + This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a + future release. + + + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated images + and the second element is a list of `bool`s indicating whether the corresponding generated image + contains "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + start_timestep = self.get_timestep_start(num_inference_steps, strength) + + if jit: + images = _p_generate( + self, + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, + ) + else: + images = self._generate( + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 5, 6, 7, 8), +) +def _p_generate( + pipe, + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, +): + return pipe._generate( + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess(image, dtype): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py new file mode 100755 index 0000000..f6bb0ac --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -0,0 +1,589 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from packaging import version +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from .pipeline_output import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> import PIL + >>> import requests + >>> from io import BytesIO + >>> from diffusers import FlaxStableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained( + ... "xvjiarui/stable-diffusion-2-inpainting" + ... ) + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> prng_seed = jax.random.PRNGKey(0) + >>> num_inference_steps = 50 + + >>> num_samples = jax.device_count() + >>> prompt = num_samples * [prompt] + >>> init_image = num_samples * [init_image] + >>> mask_image = num_samples * [mask_image] + >>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs( + ... prompt, init_image, mask_image + ... ) + # shard inputs and rng + + >>> params = replicate(params) + >>> prng_seed = jax.random.split(prng_seed, jax.device_count()) + >>> prompt_ids = shard(prompt_ids) + >>> processed_masked_images = shard(processed_masked_images) + >>> processed_masks = shard(processed_masks) + + >>> images = pipeline( + ... prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True + ... ).images + >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) + ``` +""" + + +class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): + r""" + Flax-based pipeline for text-guided image inpainting using Stable Diffusion. + + + + 🧪 This is an experimental feature! + + + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.FlaxCLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`FlaxUNet2DConditionModel`]): + A `FlaxUNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_inputs( + self, + prompt: Union[str, List[str]], + image: Union[Image.Image, List[Image.Image]], + mask: Union[Image.Image, List[Image.Image]], + ): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + if not isinstance(mask, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(mask, Image.Image): + mask = [mask] + + processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image]) + processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask]) + # processed_masks[processed_masks < 0.5] = 0 + processed_masks = processed_masks.at[processed_masks < 0.5].set(0) + # processed_masks[processed_masks >= 0.5] = 1 + processed_masks = processed_masks.at[processed_masks >= 0.5].set(1) + + processed_masked_images = processed_images * (processed_masks < 0.5) + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids, processed_masked_images, processed_masks + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.ndarray, + mask: jnp.ndarray, + masked_image: jnp.ndarray, + params: Union[Dict, FrozenDict], + prng_seed: jax.Array, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + latents: Optional[jnp.ndarray] = None, + neg_prompt_ids: Optional[jnp.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + latents_shape = ( + batch_size, + self.vae.config.latent_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + prng_seed, mask_prng_seed = jax.random.split(prng_seed) + + masked_image_latent_dist = self.vae.apply( + {"params": params["vae"]}, masked_image, method=self.vae.encode + ).latent_dist + masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2)) + masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + del mask_prng_seed + + mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest") + + # 8. Check that sizes of mask, masked image and latents match + num_channels_latents = self.vae.config.latent_channels + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + def loop_body(step, args): + latents, mask, masked_image_latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + mask_input = jnp.concatenate([mask] * 2) + masked_image_latents_input = jnp.concatenate([masked_image_latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + # concat latents, mask, masked_image_latents in the channel dimension + latents_input = jnp.concatenate([latents_input, mask_input, masked_image_latents_input], axis=1) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, mask, masked_image_latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, mask, masked_image_latents, scheduler_state = loop_body( + i, (latents, mask, masked_image_latents, scheduler_state) + ) + else: + latents, _, _, _ = jax.lax.fori_loop( + 0, num_inference_steps, loop_body, (latents, mask, masked_image_latents, scheduler_state) + ) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.ndarray, + mask: jnp.ndarray, + masked_image: jnp.ndarray, + params: Union[Dict, FrozenDict], + prng_seed: jax.Array, + num_inference_steps: int = 50, + height: Optional[int] = None, + width: Optional[int] = None, + guidance_scale: Union[float, jnp.ndarray] = 7.5, + latents: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, + return_dict: bool = True, + jit: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + latents (`jnp.ndarray`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + array is generated by sampling using the supplied random `generator`. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. + + + + This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a + future release. + + + + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated images + and the second element is a list of `bool`s indicating whether the corresponding generated image + contains "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + masked_image = jax.image.resize(masked_image, (*masked_image.shape[:-2], height, width), method="bicubic") + mask = jax.image.resize(mask, (*mask.shape[:-2], height, width), method="nearest") + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + else: + images = self._generate( + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 6, 7, 8), +) +def _p_generate( + pipe, + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, +): + return pipe._generate( + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess_image(image, dtype): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, dtype): + w, h = mask.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w, h)) + mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0 + mask = jnp.expand_dims(mask, axis=(0, 1)) + + return mask diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py new file mode 100755 index 0000000..2e34dcb --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -0,0 +1,487 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) + + +class OnnxStableDiffusionPipeline(DiffusionPipeline): + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.Tensor`): + `Image`, or tensor representing an image batch which will be upscaled. * + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + One or a list of [numpy generator(s)](TODO) to make generation deterministic. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # get the initial random noise unless the user supplied it + latents_dtype = prompt_embeds.dtype + latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline): + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + ): + deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`." + deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message) + super().__init__( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py new file mode 100755 index 0000000..c394098 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -0,0 +1,549 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt: Union[str, List[str]], + callback_steps: int, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[np.ndarray, PIL.Image.Image] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + image = preprocess(image).cpu().numpy() + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + latents_dtype = prompt_embeds.dtype + image = image.astype(latents_dtype) + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=image)[0] + init_latents = 0.18215 * init_latents + + if isinstance(prompt, str): + prompt = [prompt] + if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = len(prompt) // init_latents.shape[0] + init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0) + elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." + ) + else: + init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps.numpy()[-init_timestep] + timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ) + init_latents = init_latents.numpy() + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].numpy() + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # safety_checker does not support batched inputs yet + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py new file mode 100755 index 0000000..18d8050 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -0,0 +1,563 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +NUM_UNET_INPUT_CHANNELS = 9 +NUM_LATENT_CHANNELS = 4 + + +def prepare_mask_and_masked_image(image, mask, latents_shape): + image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8))) + image = image[None].transpose(0, 3, 1, 2) + image = image.astype(np.float32) / 127.5 - 1.0 + + image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8))) + masked_image = image * (image_mask < 127.5) + + mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"]) + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask, masked_image + + +class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: PIL.Image.Image, + mask_image: PIL.Image.Image, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + num_channels_latents = NUM_LATENT_CHANNELS + latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) + latents_dtype = prompt_embeds.dtype + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # prepare mask and masked_image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:]) + mask = mask.astype(latents.dtype) + masked_image = masked_image.astype(latents.dtype) + + masked_image_latents = self.vae_encoder(sample=masked_image)[0] + masked_image_latents = 0.18215 * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt + mask = mask.repeat(batch_size * num_images_per_prompt, 0) + masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0) + + mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + + unet_input_channels = NUM_UNET_INPUT_CHANNELS + if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels: + raise ValueError( + "Incorrect configuration settings! The config of `pipeline.unet` expects" + f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + # concat latents, mask, masked_image_latnets in the channel dimension + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # safety_checker does not support batched inputs yet + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py new file mode 100755 index 0000000..cd9ec57 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -0,0 +1,586 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers +from ...utils import deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) + + +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 32 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + return image + + +class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline): + vae: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + low_res_scheduler: DDPMScheduler + scheduler: KarrasDiffusionSchedulers + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True + + def __init__( + self, + vae: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: Any, + unet: OnnxRuntimeModel, + low_res_scheduler: DDPMScheduler, + scheduler: KarrasDiffusionSchedulers, + safety_checker: Optional[OnnxRuntimeModel] = None, + feature_extractor: Optional[CLIPImageProcessor] = None, + max_noise_level: int = 350, + num_latent_channels=4, + num_unet_input_channels=7, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + low_res_scheduler=low_res_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config( + max_noise_level=max_noise_level, + num_latent_channels=num_latent_channels, + num_unet_input_channels=num_unet_input_channels, + ) + + def check_inputs( + self, + prompt: Union[str, List[str]], + image, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, np.ndarray) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor or numpy array + if isinstance(image, (list, np.ndarray)): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + # check noise level + if noise_level > self.config.max_noise_level: + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + latents = generator.randn(*shape).astype(dtype) + elif latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + + return latents + + def decode_latents(self, latents): + latents = 1 / 0.08333 * latents + image = self.vae(latent_sample=latents)[0] + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + return image + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[np.ndarray, PIL.Image.Image, List[PIL.Image.Image]], + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 20, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[np.random.RandomState, List[np.random.RandomState]]] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + noise_level (`float`, defaults to 0.2): + Deteremines the amount of noise to add to the initial image before performing upscaling. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + image, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + latents_dtype = prompt_embeds.dtype + image = preprocess(image).cpu().numpy() + height, width = image.shape[2:] + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + self.config.num_latent_channels, + height, + width, + latents_dtype, + generator, + ) + image = image.astype(latents_dtype) + + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + # 5. Add noise to image + noise_level = np.array([noise_level]).astype(np.int64) + noise = generator.randn(*image.shape).astype(latents_dtype) + + image = self.low_res_scheduler.add_noise( + torch.from_numpy(image), torch.from_numpy(noise), torch.from_numpy(noise_level) + ) + image = image.numpy() + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = np.concatenate([image] * batch_multiplier * num_images_per_prompt) + noise_level = np.concatenate([noise_level] * image.shape[0]) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if self.config.num_latent_channels + num_channels_image != self.config.num_unet_input_channels: + raise ValueError( + "Incorrect configuration settings! The config of `pipeline.unet` expects" + f" {self.config.num_unet_input_channels} but received `num_channels_latents`: {self.config.num_latent_channels} +" + f" `num_channels_image`: {num_channels_image} " + f" = {self.config.num_latent_channels + num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = np.concatenate([latent_model_input, image], axis=1) + + # timestep to tensor + timestep = np.array([t], dtype=timestep_dtype) + + # predict the noise residual + noise_pred = self.unet( + sample=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ).prev_sample + latents = latents.numpy() + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 10. Post-processing + image = self.decode_latents(latents) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_output.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_output.py new file mode 100755 index 0000000..5fb9b1a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_output.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput, is_flax_available + + +@dataclass +class StableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`List[bool]`) + List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or + `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +if is_flax_available(): + import flax + + @flax.struct.dataclass + class FlaxStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Flax-based Stable Diffusion pipelines. + + Args: + images (`np.ndarray`): + Denoised images of array shape of `(batch_size, height, width, num_channels)`. + nsfw_content_detected (`List[bool]`): + List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content + or `None` if safety checking could not be performed. + """ + + images: np.ndarray + nsfw_content_detected: List[bool] diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py new file mode 100755 index 0000000..a2bbec7 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -0,0 +1,1072 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) + else None + ) + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py new file mode 100755 index 0000000..7801b0d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -0,0 +1,875 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation, DPTImageProcessor + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin): + r""" + Pipeline for text-guided depth-based image-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "depth_mask"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + depth_estimator: DPTForDepthEstimation, + feature_extractor: DPTImageProcessor, + ): + super().__init__() + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + depth_estimator=depth_estimator, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device): + if isinstance(image, PIL.Image.Image): + image = [image] + else: + image = list(image) + + if isinstance(image[0], PIL.Image.Image): + width, height = image[0].size + elif isinstance(image[0], np.ndarray): + width, height = image[0].shape[:-1] + else: + height, width = image[0].shape[-2:] + + if depth_map is None: + pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values + pixel_values = pixel_values.to(device=device, dtype=dtype) + # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16. + # So we use `torch.autocast` here for half precision inference. + if torch.backends.mps.is_available(): + autocast_ctx = contextlib.nullcontext() + logger.warning( + "The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16, but autocast is not yet supported on MPS." + ) + else: + autocast_ctx = torch.autocast(device.type, dtype=dtype) + + with autocast_ctx: + depth_map = self.depth_estimator(pixel_values).predicted_depth + else: + depth_map = depth_map.to(device=device, dtype=dtype) + + depth_map = torch.nn.functional.interpolate( + depth_map.unsqueeze(1), + size=(height // self.vae_scale_factor, width // self.vae_scale_factor), + mode="bicubic", + align_corners=False, + ) + + depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0 + depth_map = depth_map.to(dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if depth_map.shape[0] < batch_size: + repeat_by = batch_size // depth_map.shape[0] + depth_map = depth_map.repeat(repeat_by, 1, 1, 1) + + depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map + return depth_map + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + depth_map: Optional[torch.Tensor] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image` or tensor representing an image batch to be used as the starting point. Can accept image + latents as `image` only if `depth_map` is not `None`. + depth_map (`torch.Tensor`, *optional*): + Depth prediction to be used as additional conditioning for the image generation process. If not + defined, it automatically predicts the depth with `self.depth_estimator`. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + ```py + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> from diffusers import StableDiffusionDepth2ImgPipeline + + >>> pipe = StableDiffusionDepth2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-depth", + ... torch_dtype=torch.float16, + ... ) + >>> pipe.to("cuda") + + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> init_image = Image.open(requests.get(url, stream=True).raw) + >>> prompt = "two tigers" + >>> n_prompt = "bad, deformed, ugly, bad anotomy" + >>> image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 1. Check inputs + self.check_inputs( + prompt, + strength, + callback_steps, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare depth mask + depth_mask = self.prepare_depth_map( + image, + depth_map, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + prompt_embeds.dtype, + device, + ) + + # 5. Preprocess image + image = self.image_processor.preprocess(image) + + # 6. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 7. Prepare latent variables + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + depth_mask = callback_outputs.pop("depth_mask", depth_mask) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py new file mode 100755 index 0000000..93a8bd1 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -0,0 +1,425 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline to generate image variations from an input image using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + # TODO: feature_extractor is required to encode images (if they are in PIL format), + # we should give a descriptive message if the pipeline doesn't have one. + _optional_components = ["safety_checker"] + model_cpu_offload_seq = "image_encoder->unet->vae" + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: CLIPVisionModelWithProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + + Examples: + + ```py + from diffusers import StableDiffusionImageVariationPipeline + from PIL import Image + from io import BytesIO + import requests + + pipe = StableDiffusionImageVariationPipeline.from_pretrained( + "lambdalabs/sd-image-variations-diffusers", revision="v2.0" + ) + pipe = pipe.to("cuda") + + url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200" + + response = requests.get(url) + image = Image.open(BytesIO(response.content)).convert("RGB") + + out = pipe(image, num_images_per_prompt=3, guidance_scale=15) + out["images"][0].save("result.jpg") + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + self.maybe_free_model_hooks() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py new file mode 100755 index 0000000..424f0e3 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -0,0 +1,1145 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionImg2ImgPipeline + + >>> device = "cuda" + >>> model_id_or_path = "runwayml/stable-diffusion-v1-5" + >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> prompt = "A fantasy landscape, trending on artstation" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + >>> images[0].save("fantasy_landscape.png") + ``` +""" + + +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-guided image-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: int = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. set timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py new file mode 100755 index 0000000..e2c5b11 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -0,0 +1,1338 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "mask", "masked_image_latents"] + + def __init__( + self, + vae: Union[AutoencoderKL, AsymmetricAutoencoderKL], + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 + if unet.config.in_channels != 9: + logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: torch.Tensor = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: int = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to + be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch + tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the + expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the + expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but + if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipe = StableDiffusionInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. set timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + condition_kwargs = {} + if isinstance(self.vae, AsymmetricAutoencoderKL): + init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) + init_image_condition = init_image.clone() + init_image = self._encode_vae_image(init_image, generator=generator) + mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) + condition_kwargs = {"image": init_image_condition, "mask": mask_condition} + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs + )[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py new file mode 100755 index 0000000..fd89b19 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -0,0 +1,906 @@ +# Copyright 2024 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import PIL_INTERPOLATION, deprecate, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionInstructPix2PixPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + num_inference_steps: int = 100, + guidance_scale: float = 7.5, + image_guidance_scale: float = 1.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept + image latents as `image`, but if passing latents directly it is not encoded again. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + image_guidance_scale (`float`, *optional*, defaults to 1.5): + Push the generated image towards the initial `image`. Image guidance scale is enabled by setting + `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely + linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a + value of at least `1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInstructPix2PixPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" + + >>> image = download_image(img_url).resize((512, 512)) + + >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( + ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "make the mountains snowy" + >>> image = pipe(prompt=prompt, image=image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Check inputs + self.check_inputs( + prompt, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._image_guidance_scale = image_guidance_scale + + device = self._execution_device + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 2. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + self.do_classifier_free_guidance, + ) + + height, width = image_latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Check that shapes of latents and image match the UNet channels + num_channels_image = image_latents.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance. + # The latents are expanded 3 times because for pix2pix the guidance\ + # is applied for both the text and the input image. + latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents + + # concat latents, image_latents in the channel dimension + scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + self.guidance_scale * (noise_pred_text - noise_pred_image) + + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + image_latents = callback_outputs.pop("image_latents", image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + else: + prompt_embeds_dtype = self.unet.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if do_classifier_free_guidance: + single_image_embeds = torch.cat( + [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds] + ) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + repeat_dims = [1] + image_embeds = [] + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + ( + single_image_embeds, + single_negative_image_embeds, + single_negative_image_embeds, + ) = single_image_embeds.chunk(3) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) + single_image_embeds = torch.cat( + [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds] + ) + else: + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + image_embeds.append(single_image_embeds) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def image_guidance_scale(self): + return self._image_guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0 diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py new file mode 100755 index 0000000..ffe02ae --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -0,0 +1,655 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import deprecate, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin): + r""" + Pipeline for upscaling Stable Diffusion output image resolution by a factor of 2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A [`EulerDiscreteScheduler`] to be used in combination with `unet` to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: EulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") + + def _encode_prompt( + self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + device=device, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + **kwargs, + ) + + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds]) + + return prompt_embeds, pooled_prompt_embeds + + def encode_prompt( + self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None or pooled_prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_length=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_encoder_out = self.text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + prompt_embeds = text_encoder_out.hidden_states[-1] + pooled_prompt_embeds = text_encoder_out.pooler_output + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_length=True, + return_tensors="pt", + ) + + uncond_encoder_out = self.text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + negative_prompt_embeds = uncond_encoder_out.hidden_states[-1] + negative_pooled_prompt_embeds = uncond_encoder_out.pooler_output + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, np.ndarray) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor + if isinstance(image, (list, torch.Tensor)): + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] if image.ndim == 4 else 1 + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image upscaling. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image` or tensor representing an image batch to be upscaled. If it's a tensor, it can be either a + latent output from a Stable Diffusion model or an image tensor in the range `[-1, 1]`. It is considered + a `latent` if `image.shape[1]` is `4`; otherwise, it is considered to be an image representation and + encoded using this pipeline's `vae` encoder. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + ```py + >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline + >>> import torch + + + >>> pipeline = StableDiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 + ... ) + >>> pipeline.to("cuda") + + >>> model_id = "stabilityai/sd-x2-latent-upscaler" + >>> upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> upscaler.to("cuda") + + >>> prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic" + >>> generator = torch.manual_seed(33) + + >>> low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images + + >>> with torch.no_grad(): + ... image = pipeline.decode_latents(low_res_latents) + >>> image = pipeline.numpy_to_pil(image)[0] + + >>> image.save("../images/a1.png") + + >>> upscaled_image = upscaler( + ... prompt=prompt, + ... image=low_res_latents, + ... num_inference_steps=20, + ... guidance_scale=0, + ... generator=generator, + ... ).images[0] + + >>> upscaled_image.save("../images/a2.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None: + batch_size = 1 if isinstance(prompt, str) else len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if guidance_scale == 0: + prompt = [""] * batch_size + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds]) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + image = image.to(dtype=prompt_embeds.dtype, device=device) + if image.shape[1] == 3: + # encode image if not in latent-space yet + image = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = image[None, :] if image.ndim == 3 else image + image = torch.cat([image] * batch_multiplier) + + # 5. Add noise to image (set to be 0): + # (see below notes from the author): + # "the This step theoretically can make the model work better on out-of-distribution inputs, but mostly just seems to make it match the input less, so it's turned off by default." + noise_level = torch.tensor([0.0], dtype=torch.float32, device=device) + noise_level = torch.cat([noise_level] * image.shape[0]) + inv_noise_level = (noise_level**2 + 1) ** (-0.5) + + image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None] + image_cond = image_cond.to(prompt_embeds.dtype) + + noise_level_embed = torch.cat( + [ + torch.ones(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device), + torch.zeros(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device), + ], + dim=1, + ) + + timestep_condition = torch.cat([noise_level_embed, pooled_prompt_embeds], dim=1) + + # 6. Prepare latent variables + height, width = image.shape[2:] + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height * 2, # 2x upscale + width * 2, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 9. Denoising loop + num_warmup_steps = 0 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + sigma = self.scheduler.sigmas[i] + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + scaled_model_input = torch.cat([scaled_model_input, image_cond], dim=1) + # preconditioning parameter based on Karras et al. (2022) (table 1) + timestep = torch.log(sigma) * 0.25 + + noise_pred = self.unet( + scaled_model_input, + timestep, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_condition, + ).sample + + # in original repo, the output contains a variance channel that's not used + noise_pred = noise_pred[:, :-1] + + # apply preconditioning, based on table 1 in Karras et al. (2022) + inv_sigma = 1 / (sigma**2 + 1) + noise_pred = inv_sigma * latent_model_input + self.scheduler.scale_model_input(sigma, t) * noise_pred + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py new file mode 100755 index 0000000..4cbbe17 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -0,0 +1,809 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionUpscalePipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-guided image super-resolution using Stable Diffusion 2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + low_res_scheduler ([`SchedulerMixin`]): + A scheduler used to add initial noise to the low resolution conditioning image. It must be an instance of + [`DDPMScheduler`]. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["watermarker", "safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + low_res_scheduler: DDPMScheduler, + scheduler: KarrasDiffusionSchedulers, + safety_checker: Optional[Any] = None, + feature_extractor: Optional[CLIPImageProcessor] = None, + watermarker: Optional[Any] = None, + max_noise_level: int = 350, + ): + super().__init__() + + if hasattr( + vae, "config" + ): # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate + is_vae_scaling_factor_set_to_0_08333 = ( + hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333 + ) + if not is_vae_scaling_factor_set_to_0_08333: + deprecation_message = ( + "The configuration file of the vae does not contain `scaling_factor` or it is set to" + f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned" + " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to" + " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging" + " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file" + ) + deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False) + vae.register_to_config(scaling_factor=0.08333) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + safety_checker=safety_checker, + watermarker=watermarker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") + self.register_to_config(max_noise_level=max_noise_level) + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + image, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, np.ndarray) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor or numpy array + if isinstance(image, (list, np.ndarray, torch.Tensor)): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + # check noise level + if noise_level > self.config.max_noise_level: + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 20, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: int = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image` or tensor representing an image batch to be upscaled. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + ```py + >>> import requests + >>> from PIL import Image + >>> from io import BytesIO + >>> from diffusers import StableDiffusionUpscalePipeline + >>> import torch + + >>> # load model and scheduler + >>> model_id = "stabilityai/stable-diffusion-x4-upscaler" + >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained( + ... model_id, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipeline = pipeline.to("cuda") + + >>> # let's download an image + >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png" + >>> response = requests.get(url) + >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> low_res_img = low_res_img.resize((128, 128)) + >>> prompt = "a white cat" + + >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] + >>> upscaled_image.save("upsampled_cat.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + image, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + image = image.to(dtype=prompt_embeds.dtype, device=device) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Add noise to image + noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + image = self.low_res_scheduler.add_noise(image, noise, noise_level) + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = torch.cat([image] * batch_multiplier * num_images_per_prompt) + noise_level = torch.cat([noise_level] * image.shape[0]) + + # 6. Prepare latent variables + height, width = image.shape[2:] + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, image], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=noise_level, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + + # Ensure latents are always the same type as the VAE + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # 11. Apply watermark + if output_type == "pil" and self.watermarker is not None: + image = self.watermarker.apply_watermark(image) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py new file mode 100755 index 0000000..41811f8 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -0,0 +1,940 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel +from ...models.embeddings import get_timestep_embedding +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableUnCLIPPipeline + + >>> pipe = StableUnCLIPPipeline.from_pretrained( + ... "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16 + ... ) # TODO update model path + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> images = pipe(prompt).images + >>> images[0].save("astronaut_horse.png") + ``` +""" + + +class StableUnCLIPPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + """ + Pipeline for text-to-image generation using stable unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + prior_tokenizer ([`CLIPTokenizer`]): + A [`CLIPTokenizer`]. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen [`CLIPTextModelWithProjection`] text-encoder. + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_scheduler ([`KarrasDiffusionSchedulers`]): + Scheduler used in the prior denoising process. + image_normalizer ([`StableUnCLIPImageNormalizer`]): + Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image + embeddings after the noise has been applied. + image_noising_scheduler ([`KarrasDiffusionSchedulers`]): + Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined + by the `noise_level`. + tokenizer ([`CLIPTokenizer`]): + A [`CLIPTokenizer`]. + text_encoder ([`CLIPTextModel`]): + Frozen [`CLIPTextModel`] text-encoder. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] to denoise the encoded image latents. + scheduler ([`KarrasDiffusionSchedulers`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + """ + + _exclude_from_cpu_offload = ["prior", "image_normalizer"] + model_cpu_offload_seq = "text_encoder->prior_text_encoder->unet->vae" + + # prior components + prior_tokenizer: CLIPTokenizer + prior_text_encoder: CLIPTextModelWithProjection + prior: PriorTransformer + prior_scheduler: KarrasDiffusionSchedulers + + # image noising components + image_normalizer: StableUnCLIPImageNormalizer + image_noising_scheduler: KarrasDiffusionSchedulers + + # regular denoising components + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModel + unet: UNet2DConditionModel + scheduler: KarrasDiffusionSchedulers + + vae: AutoencoderKL + + def __init__( + self, + # prior components + prior_tokenizer: CLIPTokenizer, + prior_text_encoder: CLIPTextModelWithProjection, + prior: PriorTransformer, + prior_scheduler: KarrasDiffusionSchedulers, + # image noising components + image_normalizer: StableUnCLIPImageNormalizer, + image_noising_scheduler: KarrasDiffusionSchedulers, + # regular denoising components + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + # vae + vae: AutoencoderKL, + ): + super().__init__() + + self.register_modules( + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_encoder, + prior=prior, + prior_scheduler=prior_scheduler, + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + vae=vae, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder + def _encode_prior_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.prior_tokenizer( + prompt, + padding="max_length", + max_length=self.prior_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.prior_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.prior_tokenizer.batch_decode( + untruncated_ids[:, self.prior_tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.prior_tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.prior_tokenizer.model_max_length] + + prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device)) + + prompt_embeds = prior_text_encoder_output.text_embeds + text_enc_hid_states = prior_text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + uncond_input = self.prior_tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.prior_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_prior_text_encoder_output = self.prior_text_encoder( + uncond_input.input_ids.to(device) + ) + + negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds + uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_enc_hid_states.shape[1] + uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1) + uncond_text_enc_hid_states = uncond_text_enc_hid_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_enc_hid_states, text_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler + def prepare_prior_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.prior_scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the prior_scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.prior_scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + noise_level, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." + ) + + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def noise_image_embeddings( + self, + image_embeds: torch.Tensor, + noise_level: int, + noise: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ): + """ + Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher + `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways: + 1. A noise schedule is applied directly to the embeddings. + 2. A vector of sinusoidal time embeddings are appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. + """ + if noise is None: + noise = randn_tensor( + image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype + ) + + noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + + self.image_normalizer.to(image_embeds.device) + image_embeds = self.image_normalizer.scale(image_embeds) + + image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) + + image_embeds = self.image_normalizer.unscale(image_embeds) + + noise_level = get_timestep_embedding( + timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + noise_level = noise_level.to(image_embeds.dtype) + + image_embeds = torch.cat((image_embeds, noise_level), 1) + + return image_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + # regular denoising process args + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 20, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 0, + # prior args + prior_num_inference_steps: int = 25, + prior_guidance_scale: float = 4.0, + prior_latents: Optional[torch.Tensor] = None, + clip_skip: Optional[int] = None, + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + noise_level (`int`, *optional*, defaults to `0`): + The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in + the final un-noised images. See [`StableUnCLIPPipeline.noise_image_embeddings`] for more details. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps in the prior denoising process. More denoising steps usually lead to a + higher quality image at the expense of slower inference. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + prior_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + embedding generation in the prior denoising process. Can be used to tweak the same generation with + different prompts. If not provided, a latents tensor is generated by sampling using the supplied random + `generator`. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning + a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_steps=callback_steps, + noise_level=noise_level, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + prior_do_classifier_free_guidance = prior_guidance_scale > 1.0 + + # 3. Encode input prompt + prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = self._encode_prior_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=prior_do_classifier_free_guidance, + ) + + # 4. Prepare prior timesteps + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + # 5. Prepare prior latent variables + embedding_dim = self.prior.config.embedding_dim + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + prior_prompt_embeds.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + prior_extra_step_kwargs = self.prepare_prior_extra_step_kwargs(generator, eta) + + # 7. Prior denoising loop + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents + latent_model_input = self.prior_scheduler.scale_model_input(latent_model_input, t) + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prior_prompt_embeds, + encoder_hidden_states=prior_text_encoder_hidden_states, + attention_mask=prior_text_mask, + ).predicted_image_embedding + + if prior_do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + **prior_extra_step_kwargs, + return_dict=False, + )[0] + + if callback is not None and i % callback_steps == 0: + callback(i, t, prior_latents) + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeds = prior_latents + + # done prior + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 8. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 9. Prepare image embeddings + image_embeds = self.noise_image_embeddings( + image_embeds=image_embeds, + noise_level=noise_level, + generator=generator, + ) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) + + # 10. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 11. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + latents = self.prepare_latents( + shape=shape, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + scheduler=self.scheduler, + ) + + # 12. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 13. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=image_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py new file mode 100755 index 0000000..2556d5e --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -0,0 +1,846 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.embeddings import get_timestep_embedding +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import StableUnCLIPImg2ImgPipeline + + >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1-unclip-small", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> prompt = "A fantasy landscape, trending on artstation" + + >>> images = pipe(init_image, prompt).images + >>> images[0].save("fantasy_landscape.png") + ``` +""" + + +class StableUnCLIPImg2ImgPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + """ + Pipeline for text-guided image-to-image generation using stable unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + feature_extractor ([`CLIPImageProcessor`]): + Feature extractor for image pre-processing before being encoded. + image_encoder ([`CLIPVisionModelWithProjection`]): + CLIP vision model for encoding images. + image_normalizer ([`StableUnCLIPImageNormalizer`]): + Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image + embeddings after the noise has been applied. + image_noising_scheduler ([`KarrasDiffusionSchedulers`]): + Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined + by the `noise_level`. + tokenizer (`~transformers.CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`)]. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen [`~transformers.CLIPTextModel`] text-encoder. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] to denoise the encoded image latents. + scheduler ([`KarrasDiffusionSchedulers`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _exclude_from_cpu_offload = ["image_normalizer"] + + # image encoding components + feature_extractor: CLIPImageProcessor + image_encoder: CLIPVisionModelWithProjection + + # image noising components + image_normalizer: StableUnCLIPImageNormalizer + image_noising_scheduler: KarrasDiffusionSchedulers + + # regular denoising components + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModel + unet: UNet2DConditionModel + scheduler: KarrasDiffusionSchedulers + + vae: AutoencoderKL + + def __init__( + self, + # image encoding components + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection, + # image noising components + image_normalizer: StableUnCLIPImageNormalizer, + image_noising_scheduler: KarrasDiffusionSchedulers, + # regular denoising components + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + # vae + vae: AutoencoderKL, + ): + super().__init__() + + self.register_modules( + feature_extractor=feature_extractor, + image_encoder=image_encoder, + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + vae=vae, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def _encode_image( + self, + image, + device, + batch_size, + num_images_per_prompt, + do_classifier_free_guidance, + noise_level, + generator, + image_embeds, + ): + dtype = next(self.image_encoder.parameters()).dtype + + if isinstance(image, PIL.Image.Image): + # the image embedding should repeated so it matches the total batch size of the prompt + repeat_by = batch_size + else: + # assume the image input is already properly batched and just needs to be repeated so + # it matches the num_images_per_prompt. + # + # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched + # `image_embeds`. If those happen to be common use cases, let's think harder about + # what the expected dimensions of inputs should be and how we handle the encoding. + repeat_by = num_images_per_prompt + + if image_embeds is None: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + + image_embeds = self.noise_image_embeddings( + image_embeds=image_embeds, + noise_level=noise_level, + generator=generator, + ) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + image_embeds = image_embeds.unsqueeze(1) + bs_embed, seq_len, _ = image_embeds.shape + image_embeds = image_embeds.repeat(1, repeat_by, 1) + image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1) + image_embeds = image_embeds.squeeze(1) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + noise_level, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." + ) + + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." + ) + + if image is not None and image_embeds is not None: + raise ValueError( + "Provide either `image` or `image_embeds`. Please make sure to define only one of the two." + ) + + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + + if image is not None: + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings + def noise_image_embeddings( + self, + image_embeds: torch.Tensor, + noise_level: int, + noise: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ): + """ + Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher + `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways: + 1. A noise schedule is applied directly to the embeddings. + 2. A vector of sinusoidal time embeddings are appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. + """ + if noise is None: + noise = randn_tensor( + image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype + ) + + noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + + self.image_normalizer.to(image_embeds.device) + image_embeds = self.image_normalizer.scale(image_embeds) + + image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) + + image_embeds = self.image_normalizer.unscale(image_embeds) + + noise_level = get_timestep_embedding( + timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + noise_level = noise_level.to(image_embeds.dtype) + + image_embeds = torch.cat((image_embeds, noise_level), 1) + + return image_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[torch.Tensor, PIL.Image.Image] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 20, + guidance_scale: float = 10, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 0, + image_embeds: Optional[torch.Tensor] = None, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be + used or prompt is initialized to `""`. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image` or tensor representing an image batch. The image is encoded to its CLIP embedding which the + `unet` is conditioned on. The image is _not_ encoded by the `vae` and then used as the latents in the + denoising process like it is in the standard Stable Diffusion text-guided image variation process. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + noise_level (`int`, *optional*, defaults to `0`): + The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in + the final un-noised images. See [`StableUnCLIPPipeline.noise_image_embeddings`] for more details. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated CLIP embeddings to condition the `unet` on. These latents are not used in the denoising + process. If you want to provide pre-generated latents, pass them to `__call__` as `latents`. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning + a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if prompt is None and prompt_embeds is None: + prompt = len(image) * [""] if isinstance(image, list) else "" + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + height=height, + width=width, + callback_steps=callback_steps, + noise_level=noise_level, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image_embeds=image_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Encoder input image + noise_level = torch.tensor([noise_level], device=device) + image_embeds = self._encode_image( + image=image, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + noise_level=noise_level, + generator=generator, + image_embeds=image_embeds, + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + if latents is None: + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=image_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 9. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/diffusers/src/diffusers/pipelines/stable_diffusion/safety_checker.py new file mode 100755 index 0000000..3a0e864 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -0,0 +1,126 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class StableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + main_input_name = "clip_input" + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if torch.is_tensor(images) or torch.is_tensor(images[0]): + images[idx] = torch.zeros_like(images[idx]) # black image + else: + images[idx] = np.zeros(images[idx].shape) # black image + + if any(has_nsfw_concepts): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + images[has_nsfw_concepts] = 0.0 # black image + + return images, has_nsfw_concepts diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/diffusers/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py new file mode 100755 index 0000000..571a4f2 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -0,0 +1,112 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.core.frozen_dict import FrozenDict +from transformers import CLIPConfig, FlaxPreTrainedModel +from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule + + +def jax_cosine_distance(emb_1, emb_2, eps=1e-12): + norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T + norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T + return jnp.matmul(norm_emb_1, norm_emb_2.T) + + +class FlaxStableDiffusionSafetyCheckerModule(nn.Module): + config: CLIPConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) + self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype) + + self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) + self.special_care_embeds = self.param( + "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim) + ) + + self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) + self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) + + def __call__(self, clip_input): + pooled_output = self.vision_model(clip_input)[1] + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign image inputs + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment + special_scores = jnp.round(special_scores, 3) + is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True) + # Use a lower threshold if an image has any special care concept + special_adjustment = is_special_care * 0.01 + + concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment + concept_scores = jnp.round(concept_scores, 3) + has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1) + + return has_nsfw_concepts + + +class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): + config_class = CLIPConfig + main_input_name = "clip_input" + module_class = FlaxStableDiffusionSafetyCheckerModule + + def __init__( + self, + config: CLIPConfig, + input_shape: Optional[Tuple] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + if input_shape is None: + input_shape = (1, 224, 224, 3) + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.Array, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensor + clip_input = jax.random.normal(rng, input_shape) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, clip_input)["params"] + + return random_params + + def __call__( + self, + clip_input, + params: dict = None, + ): + clip_input = jnp.transpose(clip_input, (0, 2, 3, 1)) + + return self.module.apply( + {"params": params or self.params}, + jnp.array(clip_input, dtype=jnp.float32), + rngs={}, + ) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py b/diffusers/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py new file mode 100755 index 0000000..3fc6b3a --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py @@ -0,0 +1,57 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): + """ + This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. + + It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image + embeddings. + """ + + @register_to_config + def __init__( + self, + embedding_dim: int = 768, + ): + super().__init__() + + self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) + self.std = nn.Parameter(torch.ones(1, embedding_dim)) + + def to( + self, + torch_device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + ): + self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype)) + self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype)) + return self + + def scale(self, embeds): + embeds = (embeds - self.mean) * 1.0 / self.std + return embeds + + def unscale(self, embeds): + embeds = (embeds * self.std) + self.mean + return embeds diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_3/__init__.py b/diffusers/src/diffusers/pipelines/stable_diffusion_3/__init__.py new file mode 100755 index 0000000..b060458 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_3/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["StableDiffusion3PipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_3"] = ["StableDiffusion3Pipeline"] + _import_structure["pipeline_stable_diffusion_3_img2img"] = ["StableDiffusion3Img2ImgPipeline"] + _import_structure["pipeline_stable_diffusion_3_inpaint"] = ["StableDiffusion3InpaintPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_stable_diffusion_3 import StableDiffusion3Pipeline + from .pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline + from .pipeline_stable_diffusion_3_inpaint import StableDiffusion3InpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_output.py b/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_output.py new file mode 100755 index 0000000..4655f44 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class StableDiffusion3PipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py new file mode 100755 index 0000000..5a10f32 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -0,0 +1,935 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import StableDiffusion3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusion3Pipeline + + >>> pipe = StableDiffusion3Pipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt).images[0] + >>> image.save("sd3.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + self.tokenizer_max_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + max_sequence_length: int = 256, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py new file mode 100755 index 0000000..96d5366 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -0,0 +1,967 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Union + +import PIL.Image +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import SD3LoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import StableDiffusion3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import AutoPipelineForImage2Image + >>> from diffusers.utils import load_image + + >>> device = "cuda" + >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers" + >>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels + ) + self.tokenizer_max_length = self.tokenizer.model_max_length + self.default_sample_size = self.transformer.config.sample_size + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + self.tokenizer_max_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + max_sequence_length: int = 256, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + strength, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + if image.shape[1] == self.vae.config.latent_channels: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.scale_noise(init_latents, timestep, noise) + latents = init_latents.to(device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + strength: float = 0.6, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + strength, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py new file mode 100755 index 0000000..440b652 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -0,0 +1,1201 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Union + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import StableDiffusion3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusion3InpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusion3InpaintPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0] + >>> image.save("sd3_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels + ) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = self.tokenizer.model_max_length + self.default_sample_size = self.transformer.config.sample_size + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + self.tokenizer_max_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + max_sequence_length: int = 256, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + strength, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 16: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: int = None, + width: int = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will ge generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + strength, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 3. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 4. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_transformer = self.transformer.config.in_channels + return_image_latents = num_channels_transformer == 16 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 6. Prepare mask latent variables + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # match the inpainting pipeline and will be updated with input + mask inpainting model later + if num_channels_transformer == 33: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if ( + num_channels_latents + num_channels_mask + num_channels_masked_image + != self.transformer.config.in_channels + ): + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" + f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.transformer` or your `mask_image` or `image` input." + ) + elif num_channels_transformer != 16: + raise ValueError( + f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}." + ) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if num_channels_transformer == 33: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if num_channels_transformer == 16: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + else: + image = latents + + do_denormalize = [True] * image.shape[0] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py b/diffusers/src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py new file mode 100755 index 0000000..cce556f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py b/diffusers/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py new file mode 100755 index 0000000..8f40fa7 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py @@ -0,0 +1,1097 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch.nn import functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import Attention +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionAttendAndExcitePipeline + + >>> pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 + ... ).to("cuda") + + + >>> prompt = "a cat and a frog" + + >>> # use get_indices function to find out indices of the tokens you want to alter + >>> pipe.get_indices(prompt) + {0: '<|startoftext|>', 1: 'a', 2: 'cat', 3: 'and', 4: 'a', 5: 'frog', 6: '<|endoftext|>'} + + >>> token_indices = [2, 5] + >>> seed = 6141 + >>> generator = torch.Generator("cuda").manual_seed(seed) + + >>> images = pipe( + ... prompt=prompt, + ... token_indices=token_indices, + ... guidance_scale=7.5, + ... generator=generator, + ... num_inference_steps=50, + ... max_iter_to_alter=25, + ... ).images + + >>> image = images[0] + >>> image.save(f"../images/{prompt}_{seed}.png") + ``` +""" + + +class AttentionStore: + @staticmethod + def get_empty_store(): + return {"down": [], "mid": [], "up": []} + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= 0 and is_cross: + if attn.shape[1] == np.prod(self.attn_res): + self.step_store[place_in_unet].append(attn) + + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers: + self.cur_att_layer = 0 + self.between_steps() + + def between_steps(self): + self.attention_store = self.step_store + self.step_store = self.get_empty_store() + + def get_average_attention(self): + average_attention = self.attention_store + return average_attention + + def aggregate_attention(self, from_where: List[str]) -> torch.Tensor: + """Aggregates the attention across the different layers and heads at the specified resolution.""" + out = [] + attention_maps = self.get_average_attention() + for location in from_where: + for item in attention_maps[location]: + cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1]) + out.append(cross_maps) + out = torch.cat(out, dim=0) + out = out.sum(0) / out.shape[0] + return out + + def reset(self): + self.cur_att_layer = 0 + self.step_store = self.get_empty_store() + self.attention_store = {} + + def __init__(self, attn_res): + """ + Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion + process + """ + self.num_att_layers = -1 + self.cur_att_layer = 0 + self.step_store = self.get_empty_store() + self.attention_store = {} + self.curr_step_index = 0 + self.attn_res = attn_res + + +class AttendExciteAttnProcessor: + def __init__(self, attnstore, place_in_unet): + super().__init__() + self.attnstore = attnstore + self.place_in_unet = place_in_unet + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + is_cross = encoder_hidden_states is not None + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + + # only need to store attention maps during the Attend and Excite process + if attention_probs.requires_grad: + self.attnstore(attention_probs, is_cross, self.place_in_unet) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion and Attend-and-Excite. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + indices, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + indices_is_list_ints = isinstance(indices, list) and isinstance(indices[0], int) + indices_is_list_list_ints = ( + isinstance(indices, list) and isinstance(indices[0], list) and isinstance(indices[0][0], int) + ) + + if not indices_is_list_ints and not indices_is_list_list_ints: + raise TypeError("`indices` must be a list of ints or a list of a list of ints") + + if indices_is_list_ints: + indices_batch_size = 1 + elif indices_is_list_list_ints: + indices_batch_size = len(indices) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if indices_batch_size != prompt_batch_size: + raise ValueError( + f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @staticmethod + def _compute_max_attention_per_index( + attention_maps: torch.Tensor, + indices: List[int], + ) -> List[torch.Tensor]: + """Computes the maximum attention value for each of the tokens we wish to alter.""" + attention_for_text = attention_maps[:, :, 1:-1] + attention_for_text *= 100 + attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1) + + # Shift indices since we removed the first token + indices = [index - 1 for index in indices] + + # Extract the maximum values + max_indices_list = [] + for i in indices: + image = attention_for_text[:, :, i] + smoothing = GaussianSmoothing().to(attention_maps.device) + input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect") + image = smoothing(input).squeeze(0).squeeze(0) + max_indices_list.append(image.max()) + return max_indices_list + + def _aggregate_and_get_max_attention_per_token( + self, + indices: List[int], + ): + """Aggregates the attention for each token and computes the max activation value for each token to alter.""" + attention_maps = self.attention_store.aggregate_attention( + from_where=("up", "down", "mid"), + ) + max_attention_per_index = self._compute_max_attention_per_index( + attention_maps=attention_maps, + indices=indices, + ) + return max_attention_per_index + + @staticmethod + def _compute_loss(max_attention_per_index: List[torch.Tensor]) -> torch.Tensor: + """Computes the attend-and-excite loss using the maximum attention value for each token.""" + losses = [max(0, 1.0 - curr_max) for curr_max in max_attention_per_index] + loss = max(losses) + return loss + + @staticmethod + def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor: + """Update the latent according to the computed loss.""" + grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] + latents = latents - step_size * grad_cond + return latents + + def _perform_iterative_refinement_step( + self, + latents: torch.Tensor, + indices: List[int], + loss: torch.Tensor, + threshold: float, + text_embeddings: torch.Tensor, + step_size: float, + t: int, + max_refinement_steps: int = 20, + ): + """ + Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent code + according to our loss objective until the given threshold is reached for all tokens. + """ + iteration = 0 + target_loss = max(0, 1.0 - threshold) + while loss > target_loss: + iteration += 1 + + latents = latents.clone().detach().requires_grad_(True) + self.unet(latents, t, encoder_hidden_states=text_embeddings).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + max_attention_per_index = self._aggregate_and_get_max_attention_per_token( + indices=indices, + ) + + loss = self._compute_loss(max_attention_per_index) + + if loss != 0: + latents = self._update_latent(latents, loss, step_size) + + logger.info(f"\t Try {iteration}. loss: {loss}") + + if iteration >= max_refinement_steps: + logger.info(f"\t Exceeded max number of iterations ({max_refinement_steps})! ") + break + + # Run one more time but don't compute gradients and update the latents. + # We just need to compute the new loss - the grad update will occur below + latents = latents.clone().detach().requires_grad_(True) + _ = self.unet(latents, t, encoder_hidden_states=text_embeddings).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + max_attention_per_index = self._aggregate_and_get_max_attention_per_token( + indices=indices, + ) + loss = self._compute_loss(max_attention_per_index) + logger.info(f"\t Finished with loss of: {loss}") + return loss, latents, max_attention_per_index + + def register_attention_control(self): + attn_procs = {} + cross_att_count = 0 + for name in self.unet.attn_processors.keys(): + if name.startswith("mid_block"): + place_in_unet = "mid" + elif name.startswith("up_blocks"): + place_in_unet = "up" + elif name.startswith("down_blocks"): + place_in_unet = "down" + else: + continue + + cross_att_count += 1 + attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet) + + self.unet.set_attn_processor(attn_procs) + self.attention_store.num_att_layers = cross_att_count + + def get_indices(self, prompt: str) -> Dict[str, int]: + """Utility function to list the indices of the tokens you wish to alte""" + ids = self.tokenizer(prompt).input_ids + indices = {i: tok for tok, i in zip(self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)))} + return indices + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + token_indices: Union[List[int], List[List[int]]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + max_iter_to_alter: int = 25, + thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8}, + scale_factor: int = 20, + attn_res: Optional[Tuple[int]] = (16, 16), + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + token_indices (`List[int]`): + The token indices to alter with attend-and-excite. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + max_iter_to_alter (`int`, *optional*, defaults to `25`): + Number of denoising steps to apply attend-and-excite. The `max_iter_to_alter` denoising steps are when + attend-and-excite is applied. For example, if `max_iter_to_alter` is `25` and there are a total of `30` + denoising steps, the first `25` denoising steps applies attend-and-excite and the last `5` will not. + thresholds (`dict`, *optional*, defaults to `{0: 0.05, 10: 0.5, 20: 0.8}`): + Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in. + scale_factor (`int`, *optional*, default to 20): + Scale factor to control the step size of each attend-and-excite update. + attn_res (`tuple`, *optional*, default computed from width and height): + The 2D resolution of the semantic attention map. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + token_indices, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + if attn_res is None: + attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32)) + self.attention_store = AttentionStore(attn_res) + original_attn_proc = self.unet.attn_processors + self.register_attention_control() + + # default config for step size from original repo + scale_range = np.linspace(1.0, 0.5, len(self.scheduler.timesteps)) + step_size = scale_factor * np.sqrt(scale_range) + + text_embeddings = ( + prompt_embeds[batch_size * num_images_per_prompt :] if do_classifier_free_guidance else prompt_embeds + ) + + if isinstance(token_indices[0], int): + token_indices = [token_indices] + + indices = [] + + for ind in token_indices: + indices = indices + [ind] * num_images_per_prompt + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Attend and excite process + with torch.enable_grad(): + latents = latents.clone().detach().requires_grad_(True) + updated_latents = [] + for latent, index, text_embedding in zip(latents, indices, text_embeddings): + # Forward pass of denoising with text conditioning + latent = latent.unsqueeze(0) + text_embedding = text_embedding.unsqueeze(0) + + self.unet( + latent, + t, + encoder_hidden_states=text_embedding, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + max_attention_per_index = self._aggregate_and_get_max_attention_per_token( + indices=index, + ) + + loss = self._compute_loss(max_attention_per_index=max_attention_per_index) + + # If this is an iterative refinement step, verify we have reached the desired threshold for all + if i in thresholds.keys() and loss > 1.0 - thresholds[i]: + loss, latent, max_attention_per_index = self._perform_iterative_refinement_step( + latents=latent, + indices=index, + loss=loss, + threshold=thresholds[i], + text_embeddings=text_embedding, + step_size=step_size[i], + t=t, + ) + + # Perform gradient update + if i < max_iter_to_alter: + if loss != 0: + latent = self._update_latent( + latents=latent, + loss=loss, + step_size=step_size[i], + ) + logger.info(f"Iteration {i} | Loss: {loss:0.4f}") + + updated_latents.append(latent) + + latents = torch.cat(updated_latents, dim=0) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 8. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + self.maybe_free_model_hooks() + # make sure to set the original attention processors back + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +class GaussianSmoothing(torch.nn.Module): + """ + Arguments: + Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input + using a depthwise convolution. + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the + gaussian kernel. dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + + # channels=1, kernel_size=kernel_size, sigma=sigma, dim=2 + def __init__( + self, + channels: int = 1, + kernel_size: int = 3, + sigma: float = 0.5, + dim: int = 2, + ): + super().__init__() + + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * dim + if isinstance(sigma, float): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer("weight", kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim)) + + def forward(self, input): + """ + Arguments: + Apply gaussian filter to input. + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py b/diffusers/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py new file mode 100755 index 0000000..e2145ed --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_diffedit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/diffusers/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py new file mode 100755 index 0000000..2b86470 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -0,0 +1,1531 @@ +# Copyright 2024 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class DiffEditInversionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + latents (`torch.Tensor`) + inverted latents tensor + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `num_timesteps * batch_size` or numpy array of shape `(num_timesteps, + batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the + diffusion pipeline. + """ + + latents: torch.Tensor + images: Union[List[PIL.Image.Image], np.ndarray] + + +EXAMPLE_DOC_STRING = """ + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipeline = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> mask_prompt = "A bowl of fruits" + >>> prompt = "A bowl of pears" + + >>> mask_image = pipeline.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) + >>> image_latents = pipeline.invert(image=init_image, prompt=mask_prompt).latents + >>> image = pipeline(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0] + ``` +""" + +EXAMPLE_INVERT_DOC_STRING = """ + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipeline = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> prompt = "A bowl of fruits" + + >>> inverted_latents = pipeline.invert(image=init_image, prompt=prompt).latents + ``` +""" + + +def auto_corr_loss(hidden_states, generator=None): + reg_loss = 0.0 + for i in range(hidden_states.shape[0]): + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2) + return reg_loss + + +def kl_divergence(hidden_states): + return hidden_states.var() + hidden_states.mean() ** 2 - 1 - torch.log(hidden_states.var() + 1e-7) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def preprocess_mask(mask, batch_size: int = 1): + if not isinstance(mask, torch.Tensor): + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list): + if isinstance(mask[0], PIL.Image.Image): + mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask] + if isinstance(mask[0], np.ndarray): + mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0) + mask = torch.from_numpy(mask) + elif isinstance(mask[0], torch.Tensor): + mask = torch.stack(mask, dim=0) if mask[0].ndim < 3 else torch.cat(mask, dim=0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + # Check mask shape + if batch_size > 1: + if mask.shape[0] == 1: + mask = torch.cat([mask] * batch_size) + elif mask.shape[0] > 1 and mask.shape[0] != batch_size: + raise ValueError( + f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} " + f"inferred by prompt inputs" + ) + + if mask.shape[1] != 1: + raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("`mask_image` should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask + + +class StableDiffusionDiffEditPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + r""" + + + This is an experimental feature! + + + + Pipeline for text-guided image inpainting using Stable Diffusion and DiffEdit. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading and saving methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + inverse_scheduler ([`DDIMInverseScheduler`]): + A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + inverse_scheduler: DDIMInverseScheduler, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + inverse_scheduler=inverse_scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (strength is None) or (strength is not None and (strength < 0 or strength > 1)): + raise ValueError( + f"The value of `strength` should in [0.0, 1.0] but is, but is {strength} of type {type(strength)}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def check_source_inputs( + self, + source_prompt=None, + source_negative_prompt=None, + source_prompt_embeds=None, + source_negative_prompt_embeds=None, + ): + if source_prompt is not None and source_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `source_prompt`: {source_prompt} and `source_prompt_embeds`: {source_prompt_embeds}." + " Please make sure to only forward one of the two." + ) + elif source_prompt is None and source_prompt_embeds is None: + raise ValueError( + "Provide either `source_image` or `source_prompt_embeds`. Cannot leave all both of the arguments undefined." + ) + elif source_prompt is not None and ( + not isinstance(source_prompt, str) and not isinstance(source_prompt, list) + ): + raise ValueError(f"`source_prompt` has to be of type `str` or `list` but is {type(source_prompt)}") + + if source_negative_prompt is not None and source_negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `source_negative_prompt`: {source_negative_prompt} and `source_negative_prompt_embeds`:" + f" {source_negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if source_prompt_embeds is not None and source_negative_prompt_embeds is not None: + if source_prompt_embeds.shape != source_negative_prompt_embeds.shape: + raise ValueError( + "`source_prompt_embeds` and `source_negative_prompt_embeds` must have the same shape when passed" + f" directly, but got: `source_prompt_embeds` {source_prompt_embeds.shape} !=" + f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def get_inverse_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + + # safety for t_start overflow to prevent empty timsteps slice + if t_start == 0: + return self.inverse_scheduler.timesteps, num_inference_steps + timesteps = self.inverse_scheduler.timesteps[:-t_start] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + if batch_size != latents.shape[0]: + if batch_size % latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_latents_per_image = batch_size // latents.shape[0] + latents = torch.cat([latents] * additional_latents_per_image, dim=0) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." + ) + else: + latents = torch.cat([latents], dim=0) + + return latents + + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def generate_mask( + self, + image: Union[torch.Tensor, PIL.Image.Image] = None, + target_prompt: Optional[Union[str, List[str]]] = None, + target_negative_prompt: Optional[Union[str, List[str]]] = None, + target_prompt_embeds: Optional[torch.Tensor] = None, + target_negative_prompt_embeds: Optional[torch.Tensor] = None, + source_prompt: Optional[Union[str, List[str]]] = None, + source_negative_prompt: Optional[Union[str, List[str]]] = None, + source_prompt_embeds: Optional[torch.Tensor] = None, + source_negative_prompt_embeds: Optional[torch.Tensor] = None, + num_maps_per_mask: Optional[int] = 10, + mask_encode_strength: Optional[float] = 0.5, + mask_thresholding_ratio: Optional[float] = 3.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "np", + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Generate a latent mask given a mask prompt, a target prompt, and an image. + + Args: + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to be used for computing the mask. + target_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide semantic mask generation. If not defined, you need to pass + `prompt_embeds`. + target_negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + target_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + target_negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + source_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide semantic mask generation using DiffEdit. If not defined, you need to + pass `source_prompt_embeds` or `source_image` instead. + source_negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide semantic mask generation away from using DiffEdit. If not defined, you + need to pass `source_negative_prompt_embeds` or `source_image` instead. + source_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings to guide the semantic mask generation. Can be used to easily tweak text + inputs (prompt weighting). If not provided, text embeddings are generated from `source_prompt` input + argument. + source_negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings to negatively guide the semantic mask generation. Can be used to easily + tweak text inputs (prompt weighting). If not provided, text embeddings are generated from + `source_negative_prompt` input argument. + num_maps_per_mask (`int`, *optional*, defaults to 10): + The number of noise maps sampled to generate the semantic mask using DiffEdit. + mask_encode_strength (`float`, *optional*, defaults to 0.5): + The strength of the noise maps sampled to generate the semantic mask using DiffEdit. Must be between 0 + and 1. + mask_thresholding_ratio (`float`, *optional*, defaults to 3.0): + The maximum multiple of the mean absolute difference used to clamp the semantic guidance map before + mask binarization. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the + [`~models.attention_processor.AttnProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + `List[PIL.Image.Image]` or `np.array`: + When returning a `List[PIL.Image.Image]`, the list consists of a batch of single-channel binary images + with dimensions `(height // self.vae_scale_factor, width // self.vae_scale_factor)`. If it's + `np.array`, the shape is `(batch_size, height // self.vae_scale_factor, width // + self.vae_scale_factor)`. + """ + + # 1. Check inputs (Provide dummy argument for callback_steps) + self.check_inputs( + target_prompt, + mask_encode_strength, + 1, + target_negative_prompt, + target_prompt_embeds, + target_negative_prompt_embeds, + ) + + self.check_source_inputs( + source_prompt, + source_negative_prompt, + source_prompt_embeds, + source_negative_prompt_embeds, + ) + + if (num_maps_per_mask is None) or ( + num_maps_per_mask is not None and (not isinstance(num_maps_per_mask, int) or num_maps_per_mask <= 0) + ): + raise ValueError( + f"`num_maps_per_mask` has to be a positive integer but is {num_maps_per_mask} of type" + f" {type(num_maps_per_mask)}." + ) + + if mask_thresholding_ratio is None or mask_thresholding_ratio <= 0: + raise ValueError( + f"`mask_thresholding_ratio` has to be positive but is {mask_thresholding_ratio} of type" + f" {type(mask_thresholding_ratio)}." + ) + + # 2. Define call parameters + if target_prompt is not None and isinstance(target_prompt, str): + batch_size = 1 + elif target_prompt is not None and isinstance(target_prompt, list): + batch_size = len(target_prompt) + else: + batch_size = target_prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None) + target_negative_prompt_embeds, target_prompt_embeds = self.encode_prompt( + target_prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + target_negative_prompt, + prompt_embeds=target_prompt_embeds, + negative_prompt_embeds=target_negative_prompt_embeds, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + target_prompt_embeds = torch.cat([target_negative_prompt_embeds, target_prompt_embeds]) + + source_negative_prompt_embeds, source_prompt_embeds = self.encode_prompt( + source_prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + source_negative_prompt, + prompt_embeds=source_prompt_embeds, + negative_prompt_embeds=source_negative_prompt_embeds, + ) + if do_classifier_free_guidance: + source_prompt_embeds = torch.cat([source_negative_prompt_embeds, source_prompt_embeds]) + + # 4. Preprocess image + image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, _ = self.get_timesteps(num_inference_steps, mask_encode_strength, device) + encode_timestep = timesteps[0] + + # 6. Prepare image latents and add noise with specified strength + image_latents = self.prepare_image_latents( + image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator + ) + noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=self.vae.dtype) + image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep) + + latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep) + + # 7. Predict the noise residual + prompt_embeds = torch.cat([source_prompt_embeds, target_prompt_embeds]) + noise_pred = self.unet( + latent_model_input, + encode_timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if do_classifier_free_guidance: + noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4) + noise_pred_source = noise_pred_neg_src + guidance_scale * (noise_pred_source - noise_pred_neg_src) + noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond) + else: + noise_pred_source, noise_pred_target = noise_pred.chunk(2) + + # 8. Compute the mask from the absolute difference of predicted noise residuals + # TODO: Consider smoothing mask guidance map + mask_guidance_map = ( + torch.abs(noise_pred_target - noise_pred_source) + .reshape(batch_size, num_maps_per_mask, *noise_pred_target.shape[-3:]) + .mean([1, 2]) + ) + clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio + semantic_mask_image = mask_guidance_map.clamp(0, clamp_magnitude) / clamp_magnitude + semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1) + mask_image = semantic_mask_image.cpu().numpy() + + # 9. Convert to Numpy array or PIL. + if output_type == "pil": + mask_image = self.image_processor.numpy_to_pil(mask_image) + + # Offload all models + self.maybe_free_model_hooks() + + return mask_image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) + def invert( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[torch.Tensor, PIL.Image.Image] = None, + num_inference_steps: int = 50, + inpaint_strength: float = 0.8, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + decode_latents: bool = False, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 0, + num_auto_corr_rolls: int = 5, + ): + r""" + Generate inverted latents given a prompt and image. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to produce the inverted latents guided by `prompt`. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Indicates extent of the noising process to run latent inversion. Must be between 0 and 1. When + `inpaint_strength` is 1, the inversion process is run for the full number of iterations specified in + `num_inference_steps`. `image` is used as a reference for the inversion process, and adding more noise + increases `inpaint_strength`. If `inpaint_strength` is 0, no inpainting occurs. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + decode_latents (`bool`, *optional*, defaults to `False`): + Whether or not to decode the inverted latents into a generated image. Setting this argument to `True` + decodes all inverted latents for each timestep into a list of generated images. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the + [`~models.attention_processor.AttnProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction. + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback-Leibler divergence output. + num_reg_steps (`int`, *optional*, defaults to 0): + Number of regularization loss steps. + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or + `tuple`: + If `return_dict` is `True`, + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is the inverted latents tensors + ordered by increasing noise, and the second is the corresponding decoded images if `decode_latents` is + `True`, otherwise `None`. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. Prepare latent variables + num_images_per_prompt = 1 + latents = self.prepare_image_latents( + image, batch_size * num_images_per_prompt, self.vae.dtype, device, generator + ) + + # 5. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 6. Prepare timesteps + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_inverse_timesteps(num_inference_steps, inpaint_strength, device) + + # 7. Noising loop where we obtain the intermediate noised latent image for each timestep. + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + inverted_latents = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero) + if num_reg_steps > 0: + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = auto_corr_loss(var_epsilon, generator=generator) + l_ac.backward() + + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad + + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = kl_divergence(var_epsilon) + l_kld.backward() + + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad + + noise_pred = noise_pred.detach() + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample + inverted_latents.append(latents.detach().clone()) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + assert len(inverted_latents) == len(timesteps) + latents = torch.stack(list(reversed(inverted_latents)), 1) + + # 8. Post-processing + image = None + if decode_latents: + image = self.decode_latents(latents.flatten(0, 1)) + + # 9. Convert to PIL. + if decode_latents and output_type == "pil": + image = self.image_processor.numpy_to_pil(image) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (latents, image) + + return DiffEditInversionPipelineOutput(latents=latents, images=image) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + mask_image: Union[torch.Tensor, PIL.Image.Image] = None, + image_latents: Union[torch.Tensor, PIL.Image.Image] = None, + inpaint_strength: Optional[float] = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: int = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + mask_image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to mask the generated image. White pixels in the mask are + repainted, while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, 1, H, W)`. + image_latents (`PIL.Image.Image` or `torch.Tensor`): + Partially noised image latents from the inversion process to be used as inputs for image generation. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Indicates extent to inpaint the masked area. Must be between 0 and 1. When `inpaint_strength` is 1, the + denoising process is run on the masked area for the full number of iterations specified in + `num_inference_steps`. `image_latents` is used as a reference for the masked area, and adding more + noise to a region increases `inpaint_strength`. If `inpaint_strength` is 0, no inpainting occurs. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if mask_image is None: + raise ValueError( + "`mask_image` input cannot be undefined. Use `generate_mask()` to compute `mask_image` from text prompts." + ) + if image_latents is None: + raise ValueError( + "`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images." + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Preprocess mask + mask_image = preprocess_mask(mask_image, batch_size) + latent_height, latent_width = mask_image.shape[-2:] + mask_image = torch.cat([mask_image] * num_images_per_prompt) + mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) + + # 6. Preprocess image latents + if isinstance(image_latents, list) and any(isinstance(l, torch.Tensor) and l.ndim == 5 for l in image_latents): + image_latents = torch.cat(image_latents).detach() + elif isinstance(image_latents, torch.Tensor) and image_latents.ndim == 5: + image_latents = image_latents.detach() + else: + image_latents = self.image_processor.preprocess(image_latents).detach() + + latent_shape = (self.vae.config.latent_channels, latent_height, latent_width) + if image_latents.shape[-3:] != latent_shape: + raise ValueError( + f"Each latent image in `image_latents` must have shape {latent_shape}, " + f"but has shape {image_latents.shape[-3:]}" + ) + if image_latents.ndim == 4: + image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape) + if image_latents.shape[:2] != (batch_size, len(timesteps)): + raise ValueError( + f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)}" + f" timesteps, but has batch size {image_latents.shape[0]} with latent images from" + f" {image_latents.shape[1]} timesteps." + ) + image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1) + image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + latents = image_latents[0].clone() + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # mask with inverted latents from appropriate timestep - use original image latent for last step + latents = latents * mask_image + image_latents[i] * (1 - mask_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_gligen/__init__.py b/diffusers/src/diffusers/pipelines/stable_diffusion_gligen/__init__.py new file mode 100755 index 0000000..147980c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_gligen/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"] + _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline + from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/diffusers/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py new file mode 100755 index 0000000..52ccd56 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py @@ -0,0 +1,851 @@ +# Copyright 2024 The GLIGEN Authors and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import GatedSelfAttentionDense +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionGLIGENPipeline + >>> from diffusers.utils import load_image + + >>> # Insert objects described by text at the region defined by bounding boxes + >>> pipe = StableDiffusionGLIGENPipeline.from_pretrained( + ... "masterful/gligen-1-4-inpainting-text-box", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> input_image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png" + ... ) + >>> prompt = "a birthday cake" + >>> boxes = [[0.2676, 0.6088, 0.4773, 0.7183]] + >>> phrases = ["a birthday cake"] + + >>> images = pipe( + ... prompt=prompt, + ... gligen_phrases=phrases, + ... gligen_inpaint_image=input_image, + ... gligen_boxes=boxes, + ... gligen_scheduled_sampling_beta=1, + ... output_type="pil", + ... num_inference_steps=50, + ... ).images + + >>> images[0].save("./gligen-1-4-inpainting-text-box.jpg") + + >>> # Generate an image described by the prompt and + >>> # insert objects described by text at the region defined by bounding boxes + >>> pipe = StableDiffusionGLIGENPipeline.from_pretrained( + ... "masterful/gligen-1-4-generation-text-box", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a waterfall and a modern high speed train running through the tunnel in a beautiful forest with fall foliage" + >>> boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]] + >>> phrases = ["a waterfall", "a modern high speed train running through the tunnel"] + + >>> images = pipe( + ... prompt=prompt, + ... gligen_phrases=phrases, + ... gligen_boxes=boxes, + ... gligen_scheduled_sampling_beta=1, + ... output_type="pil", + ... num_inference_steps=50, + ... ).images + + >>> images[0].save("./gligen-1-4-generation-text-box.jpg") + ``` +""" + + +class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + model_cpu_offload_seq = "text_encoder->unet->vae" + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + gligen_phrases, + gligen_boxes, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if len(gligen_phrases) != len(gligen_boxes): + raise ValueError( + "length of `gligen_phrases` and `gligen_boxes` has to be same, but" + f" got: `gligen_phrases` {len(gligen_phrases)} != `gligen_boxes` {len(gligen_boxes)}" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_fuser(self, enabled=True): + for module in self.unet.modules(): + if type(module) is GatedSelfAttentionDense: + module.enabled = enabled + + def draw_inpaint_mask_from_boxes(self, boxes, size): + inpaint_mask = torch.ones(size[0], size[1]) + for box in boxes: + x0, x1 = box[0] * size[0], box[2] * size[0] + y0, y1 = box[1] * size[1], box[3] * size[1] + inpaint_mask[int(y0) : int(y1), int(x0) : int(x1)] = 0 + return inpaint_mask + + def crop(self, im, new_width, new_height): + width, height = im.size + left = (width - new_width) / 2 + top = (height - new_height) / 2 + right = (width + new_width) / 2 + bottom = (height + new_height) / 2 + return im.crop((left, top, right, bottom)) + + def target_size_center_crop(self, im, new_hw): + width, height = im.size + if width != height: + im = self.crop(im, min(height, width), min(height, width)) + return im.resize((new_hw, new_hw), PIL.Image.LANCZOS) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + gligen_scheduled_sampling_beta: float = 0.3, + gligen_phrases: List[str] = None, + gligen_boxes: List[List[float]] = None, + gligen_inpaint_image: Optional[PIL.Image.Image] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + gligen_phrases (`List[str]`): + The phrases to guide what to include in each of the regions defined by the corresponding + `gligen_boxes`. There should only be one phrase per bounding box. + gligen_boxes (`List[List[float]]`): + The bounding boxes that identify rectangular regions of the image that are going to be filled with the + content described by the corresponding `gligen_phrases`. Each rectangular box is defined as a + `List[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1]. + gligen_inpaint_image (`PIL.Image.Image`, *optional*): + The input image, if provided, is inpainted with objects described by the `gligen_boxes` and + `gligen_phrases`. Otherwise, it is treated as a generation task on a blank input image. + gligen_scheduled_sampling_beta (`float`, defaults to 0.3): + Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image + Generation](https://arxiv.org/pdf/2301.07093.pdf). Scheduled Sampling factor is only varied for + scheduled sampling during inference for improved quality and controllability. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + gligen_phrases, + gligen_boxes, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5.1 Prepare GLIGEN variables + max_objs = 30 + if len(gligen_boxes) > max_objs: + warnings.warn( + f"More that {max_objs} objects found. Only first {max_objs} objects will be processed.", + FutureWarning, + ) + gligen_phrases = gligen_phrases[:max_objs] + gligen_boxes = gligen_boxes[:max_objs] + # prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask) + # Get tokens for phrases from pre-trained CLIPTokenizer + tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device) + # For the token, we use the same pre-trained text encoder + # to obtain its text feature + _text_embeddings = self.text_encoder(**tokenizer_inputs).pooler_output + n_objs = len(gligen_boxes) + # For each entity, described in phrases, is denoted with a bounding box, + # we represent the location information as (xmin,ymin,xmax,ymax) + boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype) + boxes[:n_objs] = torch.tensor(gligen_boxes) + text_embeddings = torch.zeros( + max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype + ) + text_embeddings[:n_objs] = _text_embeddings + # Generate a mask for each object that is entity described by phrases + masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + masks[:n_objs] = 1 + + repeat_batch = batch_size * num_images_per_prompt + boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone() + text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone() + masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone() + if do_classifier_free_guidance: + repeat_batch = repeat_batch * 2 + boxes = torch.cat([boxes] * 2) + text_embeddings = torch.cat([text_embeddings] * 2) + masks = torch.cat([masks] * 2) + masks[: repeat_batch // 2] = 0 + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + cross_attention_kwargs["gligen"] = {"boxes": boxes, "positive_embeddings": text_embeddings, "masks": masks} + + # Prepare latent variables for GLIGEN inpainting + if gligen_inpaint_image is not None: + # if the given input image is not of the same size as expected by VAE + # center crop and resize the input image to expected shape + if gligen_inpaint_image.size != (self.vae.sample_size, self.vae.sample_size): + gligen_inpaint_image = self.target_size_center_crop(gligen_inpaint_image, self.vae.sample_size) + # Convert a single image into a batch of images with a batch size of 1 + # The resulting shape becomes (1, C, H, W), where C is the number of channels, + # and H and W are the height and width of the image. + # scales the pixel values to a range [-1, 1] + gligen_inpaint_image = self.image_processor.preprocess(gligen_inpaint_image) + gligen_inpaint_image = gligen_inpaint_image.to(dtype=self.vae.dtype, device=self.vae.device) + # Run AutoEncoder to get corresponding latents + gligen_inpaint_latent = self.vae.encode(gligen_inpaint_image).latent_dist.sample() + gligen_inpaint_latent = self.vae.config.scaling_factor * gligen_inpaint_latent + # Generate an inpainting mask + # pixel value = 0, where the object is present (defined by bounding boxes above) + # 1, everywhere else + gligen_inpaint_mask = self.draw_inpaint_mask_from_boxes(gligen_boxes, gligen_inpaint_latent.shape[2:]) + gligen_inpaint_mask = gligen_inpaint_mask.to( + dtype=gligen_inpaint_latent.dtype, device=gligen_inpaint_latent.device + ) + gligen_inpaint_mask = gligen_inpaint_mask[None, None] + gligen_inpaint_mask_addition = torch.cat( + (gligen_inpaint_latent * gligen_inpaint_mask, gligen_inpaint_mask), dim=1 + ) + # Convert a single mask into a batch of masks with a batch size of 1 + gligen_inpaint_mask_addition = gligen_inpaint_mask_addition.expand(repeat_batch, -1, -1, -1).clone() + + num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps)) + self.enable_fuser(True) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Scheduled sampling + if i == num_grounding_steps: + self.enable_fuser(False) + + if latents.shape[1] != 4: + latents = torch.randn_like(latents[:, :4]) + + if gligen_inpaint_image is not None: + gligen_inpaint_latent_with_noise = ( + self.scheduler.add_noise( + gligen_inpaint_latent, torch.randn_like(gligen_inpaint_latent), torch.tensor([t]) + ) + .expand(latents.shape[0], -1, -1, -1) + .clone() + ) + latents = gligen_inpaint_latent_with_noise * gligen_inpaint_mask + latents * ( + 1 - gligen_inpaint_mask + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if gligen_inpaint_image is not None: + latent_model_input = torch.cat((latent_model_input, gligen_inpaint_mask_addition), dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py b/diffusers/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py new file mode 100755 index 0000000..c6748ad --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py @@ -0,0 +1,1023 @@ +# Copyright 2024 The GLIGEN Authors and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import GatedSelfAttentionDense +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.clip_image_project_model import CLIPImageProjection +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionGLIGENTextImagePipeline + >>> from diffusers.utils import load_image + + >>> # Insert objects described by image at the region defined by bounding boxes + >>> pipe = StableDiffusionGLIGENTextImagePipeline.from_pretrained( + ... "anhnct/Gligen_Inpainting_Text_Image", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> input_image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png" + ... ) + >>> prompt = "a backpack" + >>> boxes = [[0.2676, 0.4088, 0.4773, 0.7183]] + >>> phrases = None + >>> gligen_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/backpack.jpeg" + ... ) + + >>> images = pipe( + ... prompt=prompt, + ... gligen_phrases=phrases, + ... gligen_inpaint_image=input_image, + ... gligen_boxes=boxes, + ... gligen_images=[gligen_image], + ... gligen_scheduled_sampling_beta=1, + ... output_type="pil", + ... num_inference_steps=50, + ... ).images + + >>> images[0].save("./gligen-inpainting-text-image-box.jpg") + + >>> # Generate an image described by the prompt and + >>> # insert objects described by text and image at the region defined by bounding boxes + >>> pipe = StableDiffusionGLIGENTextImagePipeline.from_pretrained( + ... "anhnct/Gligen_Text_Image", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a flower sitting on the beach" + >>> boxes = [[0.0, 0.09, 0.53, 0.76]] + >>> phrases = ["flower"] + >>> gligen_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/pexels-pixabay-60597.jpg" + ... ) + + >>> images = pipe( + ... prompt=prompt, + ... gligen_phrases=phrases, + ... gligen_images=[gligen_image], + ... gligen_boxes=boxes, + ... gligen_scheduled_sampling_beta=1, + ... output_type="pil", + ... num_inference_steps=50, + ... ).images + + >>> images[0].save("./gligen-generation-text-image-box.jpg") + + >>> # Generate an image described by the prompt and + >>> # transfer style described by image at the region defined by bounding boxes + >>> pipe = StableDiffusionGLIGENTextImagePipeline.from_pretrained( + ... "anhnct/Gligen_Text_Image", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a dragon flying on the sky" + >>> boxes = [[0.4, 0.2, 1.0, 0.8], [0.0, 1.0, 0.0, 1.0]] # Set `[0.0, 1.0, 0.0, 1.0]` for the style + + >>> gligen_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/landscape.png" + ... ) + + >>> gligen_placeholder = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/landscape.png" + ... ) + + >>> images = pipe( + ... prompt=prompt, + ... gligen_phrases=[ + ... "dragon", + ... "placeholder", + ... ], # Can use any text instead of `placeholder` token, because we will use mask here + ... gligen_images=[ + ... gligen_placeholder, + ... gligen_image, + ... ], # Can use any image in gligen_placeholder, because we will use mask here + ... input_phrases_mask=[1, 0], # Set 0 for the placeholder token + ... input_images_mask=[0, 1], # Set 0 for the placeholder image + ... gligen_boxes=boxes, + ... gligen_scheduled_sampling_beta=1, + ... output_type="pil", + ... num_inference_steps=50, + ... ).images + + >>> images[0].save("./gligen-generation-text-image-box-style-transfer.jpg") + ``` +""" + + +class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + processor ([`~transformers.CLIPProcessor`]): + A `CLIPProcessor` to procces reference image. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + image_project ([`CLIPImageProjection`]): + A `CLIPImageProjection` to project image embedding into phrases embedding space. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + processor: CLIPProcessor, + image_encoder: CLIPVisionModelWithProjection, + image_project: CLIPImageProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + processor=processor, + image_project=image_project, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_fuser(self, enabled=True): + for module in self.unet.modules(): + if type(module) is GatedSelfAttentionDense: + module.enabled = enabled + + def draw_inpaint_mask_from_boxes(self, boxes, size): + """ + Create an inpainting mask based on given boxes. This function generates an inpainting mask using the provided + boxes to mark regions that need to be inpainted. + """ + inpaint_mask = torch.ones(size[0], size[1]) + for box in boxes: + x0, x1 = box[0] * size[0], box[2] * size[0] + y0, y1 = box[1] * size[1], box[3] * size[1] + inpaint_mask[int(y0) : int(y1), int(x0) : int(x1)] = 0 + return inpaint_mask + + def crop(self, im, new_width, new_height): + """ + Crop the input image to the specified dimensions. + """ + width, height = im.size + left = (width - new_width) / 2 + top = (height - new_height) / 2 + right = (width + new_width) / 2 + bottom = (height + new_height) / 2 + return im.crop((left, top, right, bottom)) + + def target_size_center_crop(self, im, new_hw): + """ + Crop and resize the image to the target size while keeping the center. + """ + width, height = im.size + if width != height: + im = self.crop(im, min(height, width), min(height, width)) + return im.resize((new_hw, new_hw), PIL.Image.LANCZOS) + + def complete_mask(self, has_mask, max_objs, device): + """ + Based on the input mask corresponding value `0 or 1` for each phrases and image, mask the features + corresponding to phrases and images. + """ + mask = torch.ones(1, max_objs).type(self.text_encoder.dtype).to(device) + if has_mask is None: + return mask + + if isinstance(has_mask, int): + return mask * has_mask + else: + for idx, value in enumerate(has_mask): + mask[0, idx] = value + return mask + + def get_clip_feature(self, input, normalize_constant, device, is_image=False): + """ + Get image and phrases embedding by using CLIP pretrain model. The image embedding is transformed into the + phrases embedding space through a projection. + """ + if is_image: + if input is None: + return None + inputs = self.processor(images=[input], return_tensors="pt").to(device) + inputs["pixel_values"] = inputs["pixel_values"].to(self.image_encoder.dtype) + + outputs = self.image_encoder(**inputs) + feature = outputs.image_embeds + feature = self.image_project(feature).squeeze(0) + feature = (feature / feature.norm()) * normalize_constant + feature = feature.unsqueeze(0) + else: + if input is None: + return None + inputs = self.tokenizer(input, return_tensors="pt", padding=True).to(device) + outputs = self.text_encoder(**inputs) + feature = outputs.pooler_output + return feature + + def get_cross_attention_kwargs_with_grounded( + self, + hidden_size, + gligen_phrases, + gligen_images, + gligen_boxes, + input_phrases_mask, + input_images_mask, + repeat_batch, + normalize_constant, + max_objs, + device, + ): + """ + Prepare the cross-attention kwargs containing information about the grounded input (boxes, mask, image + embedding, phrases embedding). + """ + phrases, images = gligen_phrases, gligen_images + images = [None] * len(phrases) if images is None else images + phrases = [None] * len(images) if phrases is None else phrases + + boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype) + masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + phrases_masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + image_masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + phrases_embeddings = torch.zeros(max_objs, hidden_size, device=device, dtype=self.text_encoder.dtype) + image_embeddings = torch.zeros(max_objs, hidden_size, device=device, dtype=self.text_encoder.dtype) + + text_features = [] + image_features = [] + for phrase, image in zip(phrases, images): + text_features.append(self.get_clip_feature(phrase, normalize_constant, device, is_image=False)) + image_features.append(self.get_clip_feature(image, normalize_constant, device, is_image=True)) + + for idx, (box, text_feature, image_feature) in enumerate(zip(gligen_boxes, text_features, image_features)): + boxes[idx] = torch.tensor(box) + masks[idx] = 1 + if text_feature is not None: + phrases_embeddings[idx] = text_feature + phrases_masks[idx] = 1 + if image_feature is not None: + image_embeddings[idx] = image_feature + image_masks[idx] = 1 + + input_phrases_mask = self.complete_mask(input_phrases_mask, max_objs, device) + phrases_masks = phrases_masks.unsqueeze(0).repeat(repeat_batch, 1) * input_phrases_mask + input_images_mask = self.complete_mask(input_images_mask, max_objs, device) + image_masks = image_masks.unsqueeze(0).repeat(repeat_batch, 1) * input_images_mask + boxes = boxes.unsqueeze(0).repeat(repeat_batch, 1, 1) + masks = masks.unsqueeze(0).repeat(repeat_batch, 1) + phrases_embeddings = phrases_embeddings.unsqueeze(0).repeat(repeat_batch, 1, 1) + image_embeddings = image_embeddings.unsqueeze(0).repeat(repeat_batch, 1, 1) + + out = { + "boxes": boxes, + "masks": masks, + "phrases_masks": phrases_masks, + "image_masks": image_masks, + "phrases_embeddings": phrases_embeddings, + "image_embeddings": image_embeddings, + } + + return out + + def get_cross_attention_kwargs_without_grounded(self, hidden_size, repeat_batch, max_objs, device): + """ + Prepare the cross-attention kwargs without information about the grounded input (boxes, mask, image embedding, + phrases embedding) (All are zero tensor). + """ + boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype) + masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + phrases_masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + image_masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + phrases_embeddings = torch.zeros(max_objs, hidden_size, device=device, dtype=self.text_encoder.dtype) + image_embeddings = torch.zeros(max_objs, hidden_size, device=device, dtype=self.text_encoder.dtype) + + out = { + "boxes": boxes.unsqueeze(0).repeat(repeat_batch, 1, 1), + "masks": masks.unsqueeze(0).repeat(repeat_batch, 1), + "phrases_masks": phrases_masks.unsqueeze(0).repeat(repeat_batch, 1), + "image_masks": image_masks.unsqueeze(0).repeat(repeat_batch, 1), + "phrases_embeddings": phrases_embeddings.unsqueeze(0).repeat(repeat_batch, 1, 1), + "image_embeddings": image_embeddings.unsqueeze(0).repeat(repeat_batch, 1, 1), + } + + return out + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + gligen_scheduled_sampling_beta: float = 0.3, + gligen_phrases: List[str] = None, + gligen_images: List[PIL.Image.Image] = None, + input_phrases_mask: Union[int, List[int]] = None, + input_images_mask: Union[int, List[int]] = None, + gligen_boxes: List[List[float]] = None, + gligen_inpaint_image: Optional[PIL.Image.Image] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + gligen_normalize_constant: float = 28.7, + clip_skip: int = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + gligen_phrases (`List[str]`): + The phrases to guide what to include in each of the regions defined by the corresponding + `gligen_boxes`. There should only be one phrase per bounding box. + gligen_images (`List[PIL.Image.Image]`): + The images to guide what to include in each of the regions defined by the corresponding `gligen_boxes`. + There should only be one image per bounding box + input_phrases_mask (`int` or `List[int]`): + pre phrases mask input defined by the correspongding `input_phrases_mask` + input_images_mask (`int` or `List[int]`): + pre images mask input defined by the correspongding `input_images_mask` + gligen_boxes (`List[List[float]]`): + The bounding boxes that identify rectangular regions of the image that are going to be filled with the + content described by the corresponding `gligen_phrases`. Each rectangular box is defined as a + `List[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1]. + gligen_inpaint_image (`PIL.Image.Image`, *optional*): + The input image, if provided, is inpainted with objects described by the `gligen_boxes` and + `gligen_phrases`. Otherwise, it is treated as a generation task on a blank input image. + gligen_scheduled_sampling_beta (`float`, defaults to 0.3): + Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image + Generation](https://arxiv.org/pdf/2301.07093.pdf). Scheduled Sampling factor is only varied for + scheduled sampling during inference for improved quality and controllability. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + gligen_normalize_constant (`float`, *optional*, defaults to 28.7): + The normalize value of the image embedding. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5.1 Prepare GLIGEN variables + max_objs = 30 + if len(gligen_boxes) > max_objs: + warnings.warn( + f"More that {max_objs} objects found. Only first {max_objs} objects will be processed.", + FutureWarning, + ) + gligen_phrases = gligen_phrases[:max_objs] + gligen_boxes = gligen_boxes[:max_objs] + gligen_images = gligen_images[:max_objs] + + repeat_batch = batch_size * num_images_per_prompt + + if do_classifier_free_guidance: + repeat_batch = repeat_batch * 2 + + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + hidden_size = prompt_embeds.shape[2] + + cross_attention_kwargs["gligen"] = self.get_cross_attention_kwargs_with_grounded( + hidden_size=hidden_size, + gligen_phrases=gligen_phrases, + gligen_images=gligen_images, + gligen_boxes=gligen_boxes, + input_phrases_mask=input_phrases_mask, + input_images_mask=input_images_mask, + repeat_batch=repeat_batch, + normalize_constant=gligen_normalize_constant, + max_objs=max_objs, + device=device, + ) + + cross_attention_kwargs_without_grounded = {} + cross_attention_kwargs_without_grounded["gligen"] = self.get_cross_attention_kwargs_without_grounded( + hidden_size=hidden_size, repeat_batch=repeat_batch, max_objs=max_objs, device=device + ) + + # Prepare latent variables for GLIGEN inpainting + if gligen_inpaint_image is not None: + # if the given input image is not of the same size as expected by VAE + # center crop and resize the input image to expected shape + if gligen_inpaint_image.size != (self.vae.sample_size, self.vae.sample_size): + gligen_inpaint_image = self.target_size_center_crop(gligen_inpaint_image, self.vae.sample_size) + # Convert a single image into a batch of images with a batch size of 1 + # The resulting shape becomes (1, C, H, W), where C is the number of channels, + # and H and W are the height and width of the image. + # scales the pixel values to a range [-1, 1] + gligen_inpaint_image = self.image_processor.preprocess(gligen_inpaint_image) + gligen_inpaint_image = gligen_inpaint_image.to(dtype=self.vae.dtype, device=self.vae.device) + # Run AutoEncoder to get corresponding latents + gligen_inpaint_latent = self.vae.encode(gligen_inpaint_image).latent_dist.sample() + gligen_inpaint_latent = self.vae.config.scaling_factor * gligen_inpaint_latent + # Generate an inpainting mask + # pixel value = 0, where the object is present (defined by bounding boxes above) + # 1, everywhere else + gligen_inpaint_mask = self.draw_inpaint_mask_from_boxes(gligen_boxes, gligen_inpaint_latent.shape[2:]) + gligen_inpaint_mask = gligen_inpaint_mask.to( + dtype=gligen_inpaint_latent.dtype, device=gligen_inpaint_latent.device + ) + gligen_inpaint_mask = gligen_inpaint_mask[None, None] + gligen_inpaint_mask_addition = torch.cat( + (gligen_inpaint_latent * gligen_inpaint_mask, gligen_inpaint_mask), dim=1 + ) + # Convert a single mask into a batch of masks with a batch size of 1 + gligen_inpaint_mask_addition = gligen_inpaint_mask_addition.expand(repeat_batch, -1, -1, -1).clone() + + int(gligen_scheduled_sampling_beta * len(timesteps)) + self.enable_fuser(True) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if latents.shape[1] != 4: + latents = torch.randn_like(latents[:, :4]) + + if gligen_inpaint_image is not None: + gligen_inpaint_latent_with_noise = ( + self.scheduler.add_noise( + gligen_inpaint_latent, torch.randn_like(gligen_inpaint_latent), torch.tensor([t]) + ) + .expand(latents.shape[0], -1, -1, -1) + .clone() + ) + latents = gligen_inpaint_latent_with_noise * gligen_inpaint_mask + latents * ( + 1 - gligen_inpaint_mask + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if gligen_inpaint_image is not None: + latent_model_input = torch.cat((latent_model_input, gligen_inpaint_mask_addition), dim=1) + + # predict the noise residual with grounded information + noise_pred_with_grounding = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # predict the noise residual without grounded information + noise_pred_without_grounding = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs_without_grounded, + ).sample + + # perform guidance + if do_classifier_free_guidance: + # Using noise_pred_text from noise residual with grounded information and noise_pred_uncond from noise residual without grounded information + _, noise_pred_text = noise_pred_with_grounding.chunk(2) + noise_pred_uncond, _ = noise_pred_without_grounding.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = noise_pred_with_grounding + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py new file mode 100755 index 0000000..7eb5bf8 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_k_diffusion_available, + is_k_diffusion_version, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not ( + is_transformers_available() + and is_torch_available() + and is_k_diffusion_available() + and is_k_diffusion_version(">=", "0.0.12") + ): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) +else: + _import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"] + _import_structure["pipeline_stable_diffusion_xl_k_diffusion"] = ["StableDiffusionXLKDiffusionPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not ( + is_transformers_available() + and is_torch_available() + and is_k_diffusion_available() + and is_k_diffusion_version(">=", "0.0.12") + ): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * + else: + from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline + from .pipeline_stable_diffusion_xl_k_diffusion import StableDiffusionXLKDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/diffusers/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py new file mode 100755 index 0000000..122701f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -0,0 +1,670 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +from typing import Callable, List, Optional, Union + +import torch +from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser +from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import LMSDiscreteScheduler +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ModelWrapper: + def __init__(self, model, alphas_cumprod): + self.model = model + self.alphas_cumprod = alphas_cumprod + + def apply_model(self, *args, **kwargs): + if len(args) == 3: + encoder_hidden_states = args[-1] + args = args[:2] + if kwargs.get("cond", None) is not None: + encoder_hidden_states = kwargs.pop("cond") + return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample + + +class StableDiffusionKDiffusionPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + + + This is an experimental pipeline and is likely to change in the future. + + + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + logger.info( + f"{self.__class__} is an experimntal pipeline and is likely to change in the future. We recommend to use" + " this pipeline for fast experimentation / iteration if needed, but advice to rely on existing pipelines" + " as defined in https://huggingface.co/docs/diffusers/api/schedulers#implemented-schedulers for" + " production settings." + ) + + # get correct sigmas from LMS + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + model = ModelWrapper(unet, scheduler.alphas_cumprod) + if scheduler.config.prediction_type == "v_prediction": + self.k_diffusion_model = CompVisVDenoiser(model) + else: + self.k_diffusion_model = CompVisDenoiser(model) + + def set_scheduler(self, scheduler_type: str): + library = importlib.import_module("k_diffusion") + sampling = getattr(library, "sampling") + try: + self.sampler = getattr(sampling, scheduler_type) + except Exception: + valid_samplers = [] + for s in dir(sampling): + if "sample_" in s: + valid_samplers.append(s) + + raise ValueError(f"Invalid scheduler type {scheduler_type}. Please choose one of {valid_samplers}.") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + use_karras_sigmas: Optional[bool] = False, + noise_sampler_seed: Optional[int] = None, + clip_skip: int = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to + `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M + Karras`. + noise_sampler_seed (`int`, *optional*, defaults to `None`): + The random seed to use for the noise sampler. If `None`, a random seed will be generated. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = True + if guidance_scale <= 1.0: + raise ValueError("has to use guidance_scale") + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) + + # 5. Prepare sigmas + if use_karras_sigmas: + sigma_min: float = self.k_diffusion_model.sigmas[0].item() + sigma_max: float = self.k_diffusion_model.sigmas[-1].item() + sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max) + else: + sigmas = self.scheduler.sigmas + sigmas = sigmas.to(device) + sigmas = sigmas.to(prompt_embeds.dtype) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + latents = latents * sigmas[0] + self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) + self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) + + # 7. Define model function + def model_fn(x, t): + latent_model_input = torch.cat([x] * 2) + t = torch.cat([t] * 2) + + noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds) + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + return noise_pred + + # 8. Run k-diffusion solver + sampler_kwargs = {} + + if "noise_sampler" in inspect.signature(self.sampler).parameters: + min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed) + sampler_kwargs["noise_sampler"] = noise_sampler + + if "generator" in inspect.signature(self.sampler).parameters: + sampler_kwargs["generator"] = generator + + latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py b/diffusers/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py new file mode 100755 index 0000000..45f814f --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py @@ -0,0 +1,892 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +from typing import List, Optional, Tuple, Union + +import torch +from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser +from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras +from transformers import ( + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + FusedAttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLKDiffusionPipeline + + >>> pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + >>> pipe.set_scheduler("sample_dpmpp_2m_sde") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.ModelWrapper +class ModelWrapper: + def __init__(self, model, alphas_cumprod): + self.model = model + self.alphas_cumprod = alphas_cumprod + + def apply_model(self, *args, **kwargs): + if len(args) == 3: + encoder_hidden_states = args[-1] + args = args[:2] + if kwargs.get("cond", None) is not None: + encoder_hidden_states = kwargs.pop("cond") + return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample + + +class StableDiffusionXLKDiffusionPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL and k-diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + # get correct sigmas from LMS + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + model = ModelWrapper(unet, scheduler.alphas_cumprod) + if scheduler.config.prediction_type == "v_prediction": + self.k_diffusion_model = CompVisVDenoiser(model) + else: + self.k_diffusion_model = CompVisDenoiser(model) + + # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.set_scheduler + def set_scheduler(self, scheduler_type: str): + library = importlib.import_module("k_diffusion") + sampling = getattr(library, "sampling") + try: + self.sampler = getattr(sampling, scheduler_type) + except Exception: + valid_samplers = [] + for s in dir(sampling): + if "sample_" in s: + valid_samplers.append(s) + + raise ValueError(f"Invalid scheduler type {scheduler_type}. Please choose one of {valid_samplers}.") + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + use_karras_sigmas: Optional[bool] = False, + noise_sampler_seed: Optional[int] = None, + clip_skip: Optional[int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + if guidance_scale <= 1.0: + raise ValueError("has to use guidance_scale") + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = None + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) + + # 5. Prepare sigmas + if use_karras_sigmas: + sigma_min: float = self.k_diffusion_model.sigmas[0].item() + sigma_max: float = self.k_diffusion_model.sigmas[-1].item() + sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max) + else: + sigmas = self.scheduler.sigmas + sigmas = sigmas.to(dtype=prompt_embeds.dtype, device=device) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + latents = latents * sigmas[0] + + self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) + self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # 8. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 9. Define model function + def model_fn(x, t): + latent_model_input = torch.cat([x] * 2) + t = torch.cat([t] * 2) + + noise_pred = self.k_diffusion_model( + latent_model_input, + t, + cond=prompt_embeds, + timestep_cond=timestep_cond, + added_cond_kwargs=added_cond_kwargs, + ) + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + return noise_pred + + # 10. Run k-diffusion solver + sampler_kwargs = {} + + if "noise_sampler" in inspect.signature(self.sampler).parameters: + min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed) + sampler_kwargs["noise_sampler"] = noise_sampler + + if "generator" in inspect.signature(self.sampler).parameters: + sampler_kwargs["generator"] = generator + + latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_ldm3d/__init__.py b/diffusers/src/diffusers/pipelines/stable_diffusion_ldm3d/__init__.py new file mode 100755 index 0000000..dae2aff --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_ldm3d/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/diffusers/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py new file mode 100755 index 0000000..251ec12 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -0,0 +1,1013 @@ +# Copyright 2024 The Intel Labs Team Authors and the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessorLDM3D +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> from diffusers import StableDiffusionLDM3DPipeline + + >>> pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c") + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> output = pipe(prompt) + >>> rgb_image, depth_image = output.rgb, output.depth + >>> rgb_image[0].save("astronaut_ldm3d_rgb.jpg") + >>> depth_image[0].save("astronaut_ldm3d_depth.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class LDM3DPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + rgb (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + depth (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`List[bool]`) + List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or + `None` if safety checking could not be performed. + """ + + rgb: Union[List[PIL.Image.Image], np.ndarray] + depth: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +class StableDiffusionLDM3DPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image and 3D generation using LDM3D. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + rgb_feature_extractor_input = feature_extractor_input[0] + safety_checker_input = self.feature_extractor(rgb_feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 49, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return ((rgb, depth), has_nsfw_concept) + + return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_panorama/__init__.py b/diffusers/src/diffusers/pipelines/stable_diffusion_panorama/__init__.py new file mode 100755 index 0000000..f7572db --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_panorama/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/diffusers/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py new file mode 100755 index 0000000..96fba06 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -0,0 +1,1168 @@ +# Copyright 2024 MultiDiffusion Authors and The HuggingFace Team. All rights reserved." +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import DDIMScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler + + >>> model_ckpt = "stabilityai/stable-diffusion-2-base" + >>> scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler") + >>> pipe = StableDiffusionPanoramaPipeline.from_pretrained( + ... model_ckpt, scheduler=scheduler, torch_dtype=torch.float16 + ... ) + + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of the dolomites" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPanoramaPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using MultiDiffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def decode_latents_with_padding(self, latents: torch.Tensor, padding: int = 8) -> torch.Tensor: + """ + Decode the given latents with padding for circular inference. + + Args: + latents (torch.Tensor): The input latents to decode. + padding (int, optional): The number of latents to add on each side for padding. Defaults to 8. + + Returns: + torch.Tensor: The decoded image with padding removed. + + Notes: + - The padding is added to remove boundary artifacts and improve the output quality. + - This would slightly increase the memory usage. + - The padding pixels are then removed from the decoded image. + + """ + latents = 1 / self.vae.config.scaling_factor * latents + latents_left = latents[..., :padding] + latents_right = latents[..., -padding:] + latents = torch.cat((latents_right, latents, latents_left), axis=-1) + image = self.vae.decode(latents, return_dict=False)[0] + padding_pix = self.vae_scale_factor * padding + image = image[..., padding_pix:-padding_pix] + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + def get_views( + self, + panorama_height: int, + panorama_width: int, + window_size: int = 64, + stride: int = 8, + circular_padding: bool = False, + ) -> List[Tuple[int, int, int, int]]: + """ + Generates a list of views based on the given parameters. Here, we define the mappings F_i (see Eq. 7 in the + MultiDiffusion paper https://arxiv.org/abs/2302.08113). If panorama's height/width < window_size, num_blocks of + height/width should return 1. + + Args: + panorama_height (int): The height of the panorama. + panorama_width (int): The width of the panorama. + window_size (int, optional): The size of the window. Defaults to 64. + stride (int, optional): The stride value. Defaults to 8. + circular_padding (bool, optional): Whether to apply circular padding. Defaults to False. + + Returns: + List[Tuple[int, int, int, int]]: A list of tuples representing the views. Each tuple contains four integers + representing the start and end coordinates of the window in the panorama. + + """ + panorama_height /= 8 + panorama_width /= 8 + num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1 + if circular_padding: + num_blocks_width = panorama_width // stride if panorama_width > window_size else 1 + else: + num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) + views = [] + for i in range(total_num_blocks): + h_start = int((i // num_blocks_width) * stride) + h_end = h_start + window_size + w_start = int((i % num_blocks_width) * stride) + w_end = w_start + window_size + views.append((h_start, h_end, w_start, w_end)) + return views + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def clip_skip(self): + return self._clip_skip + + @property + def do_classifier_free_guidance(self): + return False + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = 512, + width: Optional[int] = 2048, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + view_batch_size: int = 1, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + circular_padding: bool = False, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs: Any, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 2048): + The width in pixels of the generated image. The width is kept high because the pipeline is supposed + generate panorama-like images. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + The timesteps at which to generate the images. If not specified, then the default timestep spacing + strategy of the scheduler is used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + view_batch_size (`int`, *optional*, defaults to 1): + The batch size to denoise split views. For some GPUs with high performance, higher view batch size can + speedup the generation and increase the VRAM usage. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescaling factor for the guidance embeddings. A value of 0.0 means no rescaling is applied. + circular_padding (`bool`, *optional*, defaults to `False`): + If set to `True`, circular padding is applied to ensure there are no stitching artifacts. Circular + padding allows the model to seamlessly generate a transition from the rightmost part of the image to + the leftmost part, maintaining consistency in a 360-degree sense. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Define panorama grid and initialize views for synthesis. + # prepare batch grid + views = self.get_views(height, width, circular_padding=circular_padding) + views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)] + views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch) + count = torch.zeros_like(latents) + value = torch.zeros_like(latents) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + # Each denoising step also includes refinement of the latents with respect to the + # views. + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + count.zero_() + value.zero_() + + # generate views + # Here, we iterate through different spatial crops of the latents and denoise them. These + # denoised (latent) crops are then averaged to produce the final latent + # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the + # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113 + # Batch views denoise + for j, batch_view in enumerate(views_batch): + vb_size = len(batch_view) + # get the latents corresponding to the current view coordinates + if circular_padding: + latents_for_view = [] + for h_start, h_end, w_start, w_end in batch_view: + if w_end > latents.shape[3]: + # Add circular horizontal padding + latent_view = torch.cat( + ( + latents[:, :, h_start:h_end, w_start:], + latents[:, :, h_start:h_end, : w_end - latents.shape[3]], + ), + axis=-1, + ) + else: + latent_view = latents[:, :, h_start:h_end, w_start:w_end] + latents_for_view.append(latent_view) + latents_for_view = torch.cat(latents_for_view) + else: + latents_for_view = torch.cat( + [ + latents[:, :, h_start:h_end, w_start:w_end] + for h_start, h_end, w_start, w_end in batch_view + ] + ) + + # rematch block's scheduler status + self.scheduler.__dict__.update(views_scheduler_status[j]) + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + latents_for_view.repeat_interleave(2, dim=0) + if do_classifier_free_guidance + else latents_for_view + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # repeat prompt_embeds for batch + prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds_input, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents_denoised_batch = self.scheduler.step( + noise_pred, t, latents_for_view, **extra_step_kwargs + ).prev_sample + + # save views scheduler status after sample + views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) + + # extract value from batch + for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip( + latents_denoised_batch.chunk(vb_size), batch_view + ): + if circular_padding and w_end > latents.shape[3]: + # Case for circular padding + value[:, :, h_start:h_end, w_start:] += latents_view_denoised[ + :, :, h_start:h_end, : latents.shape[3] - w_start + ] + value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[ + :, :, h_start:h_end, latents.shape[3] - w_start : + ] + count[:, :, h_start:h_end, w_start:] += 1 + count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1 + else: + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised + count[:, :, h_start:h_end, w_start:w_end] += 1 + + # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 + latents = torch.where(count > 0, value / count, value) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if output_type != "latent": + if circular_padding: + image = self.decode_latents_with_padding(latents) + else: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__init__.py b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__init__.py new file mode 100755 index 0000000..b432b94 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/__init__.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, List, Optional, Union + +import numpy as np +import PIL +from PIL import Image + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + BaseOutput, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +@dataclass +class SafetyConfig(object): + WEAK = { + "sld_warmup_steps": 15, + "sld_guidance_scale": 20, + "sld_threshold": 0.0, + "sld_momentum_scale": 0.0, + "sld_mom_beta": 0.0, + } + MEDIUM = { + "sld_warmup_steps": 10, + "sld_guidance_scale": 1000, + "sld_threshold": 0.01, + "sld_momentum_scale": 0.3, + "sld_mom_beta": 0.4, + } + STRONG = { + "sld_warmup_steps": 7, + "sld_guidance_scale": 2000, + "sld_threshold": 0.025, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + MAX = { + "sld_warmup_steps": 0, + "sld_guidance_scale": 5000, + "sld_threshold": 1.0, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {} + +_additional_imports.update({"SafetyConfig": SafetyConfig}) + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure.update( + { + "pipeline_output": ["StableDiffusionSafePipelineOutput"], + "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"], + "safety_checker": ["StableDiffusionSafetyChecker"], + } + ) + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_output import StableDiffusionSafePipelineOutput + from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe + from .safety_checker import SafeStableDiffusionSafetyChecker + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_safe/pipeline_output.py b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/pipeline_output.py new file mode 100755 index 0000000..69a064d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/pipeline_output.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL.Image + +from ...utils import ( + BaseOutput, +) + + +@dataclass +class StableDiffusionSafePipelineOutput(BaseOutput): + """ + Output class for Safe Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work" + (nsfw) content, or `None` if no safety check was performed or no images were flagged. + applied_safety_concept (`str`) + The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] + applied_safety_concept: Optional[str] diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py new file mode 100755 index 0000000..cd59cf5 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -0,0 +1,769 @@ +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionSafePipelineOutput +from .safety_checker import SafeStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAdapterMixin): + r""" + Pipeline based on the [`StableDiffusionPipeline`] for text-to-image generation using Safe Latent Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: SafeStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + requires_safety_checker: bool = True, + ): + super().__init__() + safety_concept: Optional[str] = ( + "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity," + " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child" + " abuse, brutality, cruelty" + ) + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self._safety_text_concept = safety_concept + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @property + def safety_concept(self): + r""" + Getter method for the safety concept used with SLD + + Returns: + `str`: The text describing the safety concept + """ + return self._safety_text_concept + + @safety_concept.setter + def safety_concept(self, concept): + r""" + Setter method for the safety concept used with SLD + + Args: + concept (`str`): + The text of the new safety concept + """ + self._safety_text_concept = concept + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + enable_safety_guidance, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # Encode the safety concept text + if enable_safety_guidance: + safety_concept_input = self.tokenizer( + [self._safety_text_concept], + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0] + + # duplicate safety embeddings for each generation per prompt, using mps friendly method + seq_len = safety_embeddings.shape[1] + safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1) + safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance + sld, we need to do three forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing three forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings]) + + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype, enable_safety_guidance): + if self.safety_checker is not None: + images = image.copy() + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + flagged_images = np.zeros((2, *image.shape[1:])) + if any(has_nsfw_concept): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead." + f"{'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}" + ) + for idx, has_nsfw_concept in enumerate(has_nsfw_concept): + if has_nsfw_concept: + flagged_images[idx] = images[idx] + image[idx] = np.zeros(image[idx].shape) # black image + else: + has_nsfw_concept = None + flagged_images = None + return image, has_nsfw_concept, flagged_images + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def perform_safety_guidance( + self, + enable_safety_guidance, + safety_momentum, + noise_guidance, + noise_pred_out, + i, + sld_guidance_scale, + sld_warmup_steps, + sld_threshold, + sld_momentum_scale, + sld_mom_beta, + ): + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1] + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale + ) + + # Equation 4 + noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + return noise_guidance, safety_momentum + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + sld_guidance_scale: Optional[float] = 1000, + sld_warmup_steps: Optional[int] = 10, + sld_threshold: Optional[float] = 0.01, + sld_momentum_scale: Optional[float] = 0.3, + sld_mom_beta: Optional[float] = 0.4, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + sld_guidance_scale (`float`, *optional*, defaults to 1000): + If `sld_guidance_scale < 1`, safety guidance is disabled. + sld_warmup_steps (`int`, *optional*, defaults to 10): + Number of warmup steps for safety guidance. SLD is only be applied for diffusion steps greater than + `sld_warmup_steps`. + sld_threshold (`float`, *optional*, defaults to 0.01): + Threshold that separates the hyperplane between appropriate and inappropriate images. + sld_momentum_scale (`float`, *optional*, defaults to 0.3): + Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0, + momentum is disabled. Momentum is built up during warmup for diffusion steps smaller than + `sld_warmup_steps`. + sld_mom_beta (`float`, *optional*, defaults to 0.4): + Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous + momentum is kept. Momentum is built up during warmup for diffusion steps smaller than + `sld_warmup_steps`. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + + Examples: + + ```py + import torch + from diffusers import StableDiffusionPipelineSafe + from diffusers.pipelines.stable_diffusion_safe import SafetyConfig + + pipeline = StableDiffusionPipelineSafe.from_pretrained( + "AIML-TUDA/stable-diffusion-safe", torch_dtype=torch.float16 + ).to("cuda") + prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker" + image = pipeline(prompt=prompt, **SafetyConfig.MEDIUM).images[0] + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance + if not enable_safety_guidance: + warnings.warn("Safety checker disabled!") + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if do_classifier_free_guidance: + if enable_safety_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds, image_embeds]) + else: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + safety_momentum = None + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * (3 if enable_safety_guidance else 2)) + if do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + + # default classifier free guidance + noise_guidance = noise_pred_text - noise_pred_uncond + + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp( + torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0 + ) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, + torch.zeros_like(scale), + scale, + ) + + # Equation 4 + noise_guidance_safety = torch.mul( + (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale + ) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + + noise_pred = noise_pred_uncond + guidance_scale * noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept, flagged_images = self.run_safety_checker( + image, device, prompt_embeds.dtype, enable_safety_guidance + ) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + if flagged_images is not None: + flagged_images = self.numpy_to_pil(flagged_images) + + if not return_dict: + return ( + image, + has_nsfw_concept, + self._safety_text_concept if enable_safety_guidance else None, + flagged_images, + ) + + return StableDiffusionSafePipelineOutput( + images=image, + nsfw_content_detected=has_nsfw_concept, + applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None, + unsafe_images=flagged_images, + ) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py new file mode 100755 index 0000000..338e4c6 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py @@ -0,0 +1,109 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class SafeStableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + return images, has_nsfw_concepts diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_sag/__init__.py b/diffusers/src/diffusers/pipelines/stable_diffusion_sag/__init__.py new file mode 100755 index 0000000..378e0e5 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_sag/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/diffusers/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py new file mode 100755 index 0000000..c32052d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py @@ -0,0 +1,954 @@ +# Copyright 2024 Susung Hong and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionSAGPipeline + + >>> pipe = StableDiffusionSAGPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, sag_scale=0.75).images[0] + ``` +""" + + +# processes and stores attention probabilities +class CrossAttnStoreProcessor: + def __init__(self): + self.attention_probs = None + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + self.attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(self.attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +# Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input +class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + image_embeds = ip_adapter_image_embeds + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + sag_scale: float = 0.75, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + sag_scale (`float`, *optional*, defaults to 0.75): + Chosen between [0, 1.0] for better quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the + `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # and `sag_scale` is` `s` of equation (16) + # of the self-attention guidance paper: https://arxiv.org/pdf/2210.00939.pdf + # `sag_scale = 0` means no self-attention guidance + do_self_attention_guidance = sag_scale > 0.0 + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + do_classifier_free_guidance, + ) + + if do_classifier_free_guidance: + image_embeds = [] + negative_image_embeds = [] + for tmp_image_embeds in ip_adapter_image_embeds: + single_negative_image_embeds, single_image_embeds = tmp_image_embeds.chunk(2) + image_embeds.append(single_image_embeds) + negative_image_embeds.append(single_negative_image_embeds) + else: + image_embeds = ip_adapter_image_embeds + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if timesteps.dtype not in [torch.int16, torch.int32, torch.int64]: + raise ValueError( + f"{self.__class__.__name__} does not support using a scheduler of type {self.scheduler.__class__.__name__}. Please make sure to use one of 'DDIMScheduler, PNDMScheduler, DDPMScheduler, DEISMultistepScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler'." + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + if do_classifier_free_guidance: + added_uncond_kwargs = ( + {"image_embeds": negative_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7. Denoising loop + original_attn_proc = self.unet.attn_processors + store_processor = CrossAttnStoreProcessor() + self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + map_size = None + + def get_map_size(module, input, output): + nonlocal map_size + map_size = output[0].shape[-2:] + + with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size): + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # perform self-attention guidance with the stored self-attention map + if do_self_attention_guidance: + # classifier-free guidance produces two chunks of attention map + # and we only use unconditional one according to equation (25) + # in https://arxiv.org/pdf/2210.00939.pdf + if do_classifier_free_guidance: + # DDIM-like prediction of x0 + pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) + # get the stored attention maps + uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) + # self-attention-based degrading of latents + degraded_latents = self.sag_masking( + pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t) + ) + uncond_emb, _ = prompt_embeds.chunk(2) + # forward and give guidance + degraded_pred = self.unet( + degraded_latents, + t, + encoder_hidden_states=uncond_emb, + added_cond_kwargs=added_uncond_kwargs, + ).sample + noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) + else: + # DDIM-like prediction of x0 + pred_x0 = self.pred_x0(latents, noise_pred, t) + # get the stored attention maps + cond_attn = store_processor.attention_probs + # self-attention-based degrading of latents + degraded_latents = self.sag_masking( + pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t) + ) + # forward and give guidance + degraded_pred = self.unet( + degraded_latents, + t, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + ).sample + noise_pred += sag_scale * (noise_pred - degraded_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + self.maybe_free_model_hooks() + # make sure to set the original attention processors back + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def sag_masking(self, original_latents, attn_map, map_size, t, eps): + # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf + bh, hw1, hw2 = attn_map.shape + b, latent_channel, latent_h, latent_w = original_latents.shape + h = self.unet.config.attention_head_dim + if isinstance(h, list): + h = h[-1] + + # Produce attention mask + attn_map = attn_map.reshape(b, h, hw1, hw2) + attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 + attn_mask = ( + attn_mask.reshape(b, map_size[0], map_size[1]) + .unsqueeze(1) + .repeat(1, latent_channel, 1, 1) + .type(attn_map.dtype) + ) + attn_mask = F.interpolate(attn_mask, (latent_h, latent_w)) + + # Blur according to the self-attention mask + degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) + degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) + + # Noise it again to match the noise level + degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t[None]) + + return degraded_latents + + # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step + # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.) + def pred_x0(self, sample, model_output, timestep): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device) + + beta_prod_t = 1 - alpha_prod_t + if self.scheduler.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.scheduler.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.scheduler.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," + " or `v_prediction`" + ) + + return pred_original_sample + + def pred_epsilon(self, sample, model_output, timestep): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + if self.scheduler.config.prediction_type == "epsilon": + pred_eps = model_output + elif self.scheduler.config.prediction_type == "sample": + pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) + elif self.scheduler.config.prediction_type == "v_prediction": + pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," + " or `v_prediction`" + ) + + return pred_eps + + +# Gaussian blur +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + + return img diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/__init__.py new file mode 100755 index 0000000..8088fbc --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,76 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["StableDiffusionXLPipelineOutput"]} + +if is_transformers_available() and is_flax_available(): + _import_structure["pipeline_output"].extend(["FlaxStableDiffusionXLPipelineOutput"]) +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_xl"] = ["StableDiffusionXLPipeline"] + _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"] + _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"] + _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"] + +if is_transformers_available() and is_flax_available(): + from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState + + _additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState}) + _import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline + from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline + from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline + from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_objects import * + else: + from .pipeline_flax_stable_diffusion_xl import ( + FlaxStableDiffusionXLPipeline, + ) + from .pipeline_output import FlaxStableDiffusionXLPipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py new file mode 100755 index 0000000..77363b2 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -0,0 +1,308 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict +from transformers import CLIPTokenizer, FlaxCLIPTextModel + +from diffusers.utils import logging + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from .pipeline_output import FlaxStableDiffusionXLPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + + +class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline): + def __init__( + self, + text_encoder: FlaxCLIPTextModel, + text_encoder_2: FlaxCLIPTextModel, + vae: FlaxAutoencoderKL, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_inputs(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + # Assume we have the two encoders + inputs = [] + for tokenizer in [self.tokenizer, self.tokenizer_2]: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + inputs.append(text_inputs.input_ids) + inputs = jnp.stack(inputs, axis=1) + return inputs + + def __call__( + self, + prompt_ids: jax.Array, + params: Union[Dict, FrozenDict], + prng_seed: jax.Array, + num_inference_steps: int = 50, + guidance_scale: Union[float, jax.Array] = 7.5, + height: Optional[int] = None, + width: Optional[int] = None, + latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, + return_dict: bool = True, + output_type: str = None, + jit: bool = False, + ): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(guidance_scale, float) and jit: + # Convert to a tensor so each device gets a copy. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + guidance_scale = guidance_scale[:, None] + + return_latents = output_type == "latent" + + if jit: + images = _p_generate( + self, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + return_latents, + ) + else: + images = self._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + return_latents, + ) + + if not return_dict: + return (images,) + + return FlaxStableDiffusionXLPipelineOutput(images=images) + + def get_embeddings(self, prompt_ids: jnp.array, params): + # We assume we have the two encoders + + # bs, encoder_input, seq_length + te_1_inputs = prompt_ids[:, 0, :] + te_2_inputs = prompt_ids[:, 1, :] + + prompt_embeds = self.text_encoder(te_1_inputs, params=params["text_encoder"], output_hidden_states=True) + prompt_embeds = prompt_embeds["hidden_states"][-2] + prompt_embeds_2_out = self.text_encoder_2( + te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True + ) + prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2] + text_embeds = prompt_embeds_2_out["text_embeds"] + prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1) + return prompt_embeds, text_embeds + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, bs, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = jnp.array([add_time_ids] * bs, dtype=dtype) + return add_time_ids + + def _generate( + self, + prompt_ids: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.Array, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + latents: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, + return_latents=False, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # Encode input prompt + prompt_embeds, pooled_embeds = self.get_embeddings(prompt_ids, params) + + # Get unconditional embeddings + batch_size = prompt_embeds.shape[0] + if neg_prompt_ids is None: + neg_prompt_embeds = jnp.zeros_like(prompt_embeds) + negative_pooled_embeds = jnp.zeros_like(pooled_embeds) + else: + neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params) + + add_time_ids = self._get_add_time_ids( + (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype + ) + + prompt_embeds = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0) # (2, 77, 2048) + add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0) + add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0) + + # Ensure model output will be `float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + + # Create random latents + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # Prepare scheduler state + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * scheduler_state.init_noise_sigma + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # Denoising loop + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + if return_latents: + return latents + + # Decode latents + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + +# Static argnums are pipe, num_inference_steps, height, width, return_latents. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0, None), + static_broadcasted_argnums=(0, 4, 5, 6, 10), +) +def _p_generate( + pipe, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + return_latents, +): + return pipe._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + return_latents, + ) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py new file mode 100755 index 0000000..0783f44 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput, is_flax_available + + +@dataclass +class StableDiffusionXLPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +if is_flax_available(): + import flax + + @flax.struct.dataclass + class FlaxStableDiffusionXLPipelineOutput(BaseOutput): + """ + Output class for Flax Stable Diffusion XL pipelines. + + Args: + images (`np.ndarray`) + Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline. + """ + + images: np.ndarray diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py new file mode 100755 index 0000000..94b7283 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -0,0 +1,1300 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + FusedAttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py new file mode 100755 index 0000000..29b5e11 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -0,0 +1,1494 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + strength: float = 0.3, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): + The image(s) to modify with the pipeline. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of + `denoising_start` being declared as an integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True if self.denoising_start is None else False + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + add_noise, + ) + # 7. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 9.1 Apply denoising_end + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py new file mode 100755 index 0000000..d28a9af --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -0,0 +1,1732 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... torch_dtype=torch.float16, + ... variant="fp16", + ... use_safetensors=True, + ... ) + >>> pipe.to("cuda") + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = load_image(img_url).convert("RGB") + >>> mask_image = load_image(mask_url).convert("RGB") + + >>> prompt = "A majestic tiger sitting on a bench" + >>> image = pipe( + ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def mask_pil_to_torch(mask, height, width): + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask = torch.from_numpy(mask) + return mask + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config + of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + "mask", + "masked_image_latents", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: torch.Tensor = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 0.9999): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. Note that in the case of `denoising_start` being declared as an + integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. set timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is not None: + masked_image = masked_image_latents + elif init_image.shape[1] == 4: + # if images are in latent space, we can't mask it + masked_image = None + else: + masked_image = init_image * (mask < 0.5) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = True if self.denoising_start is None else False + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 8.1 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 11.1 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + return StableDiffusionXLPipelineOutput(images=latents) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py new file mode 100755 index 0000000..b59f231 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -0,0 +1,992 @@ +# Copyright 2024 Harutatsu Akiyama and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + FusedAttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLInstructPix2PixPipeline + >>> from diffusers.utils import load_image + + >>> resolution = 768 + >>> image = load_image( + ... "https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" + ... ).resize((resolution, resolution)) + >>> edit_instruction = "Turn sky into a cloudy one" + + >>> pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( + ... "diffusers/sdxl-instructpix2pix-768", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> edited_image = pipe( + ... prompt=edit_instruction, + ... image=image, + ... height=resolution, + ... width=resolution, + ... guidance_scale=3.0, + ... image_guidance_scale=1.5, + ... num_inference_steps=30, + ... ).images[0] + >>> edited_image + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLInstructPix2PixPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, +): + r""" + Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config + of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + is_cosxl_edit (`bool`, *optional*): + When set the image latents are scaled. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + is_cosxl_edit: Optional[bool] = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + self.is_cosxl_edit = is_cosxl_edit + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt, negative_prompt_2] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + image = image.float() + self.upcast_vae() + + image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + if image_latents.dtype != self.vae.dtype: + image_latents = image_latents.to(dtype=self.vae.dtype) + + if self.is_cosxl_edit: + image_latents = image_latents * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 100, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + image_guidance_scale: float = 1.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): + The image(s) to modify with the pipeline. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + image_guidance_scale (`float`, *optional*, defaults to 1.5): + Image guidance scale is to push the generated image towards the initial image `image`. Image guidance + scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to + generate images that are closely linked to the source image `image`, usually at the expense of lower + image quality. This pipeline requires a value of at least `1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image, height=height, width=width).to(device) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + do_classifier_free_guidance, + ) + + # 7. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Check that shapes of latents and image match the UNet channels + num_channels_image = image_latents.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents + num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if do_classifier_free_guidance: + # The extra concat similar to how it's done in SD InstructPix2Pix. + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0) + add_text_embeds = torch.cat( + [add_text_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], dim=0 + ) + add_time_ids = torch.cat([add_time_ids, add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance. + # The latents are expanded 3 times because for pix2pix the guidance + # is applied for both the text and the input image. + latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + + # concat latents, image_latents in the channel dimension + scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_image) + + image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + return StableDiffusionXLPipelineOutput(images=latents) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/stable_diffusion_xl/watermark.py b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/watermark.py new file mode 100755 index 0000000..70d06bb --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_diffusion_xl/watermark.py @@ -0,0 +1,42 @@ +import numpy as np +import torch + +from ...utils import is_invisible_watermark_available + + +if is_invisible_watermark_available(): + from imwatermark import WatermarkEncoder + + +# Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66 +WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] + + +class StableDiffusionXLWatermarker: + def __init__(self): + self.watermark = WATERMARK_BITS + self.encoder = WatermarkEncoder() + + self.encoder.set_watermark("bits", self.watermark) + + def apply_watermark(self, images: torch.Tensor): + # can't encode images that are smaller than 256 + if images.shape[-1] < 256: + return images + + images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy() + + # Convert RGB to BGR, which is the channel order expected by the watermark encoder. + images = images[:, :, :, ::-1] + + # Add watermark and convert BGR back to RGB + images = [self.encoder.encode(image, "dwtDct")[:, :, ::-1] for image in images] + + images = np.array(images) + + images = torch.from_numpy(images).permute(0, 3, 1, 2) + + images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0) + return images diff --git a/diffusers/src/diffusers/pipelines/stable_video_diffusion/__init__.py b/diffusers/src/diffusers/pipelines/stable_video_diffusion/__init__.py new file mode 100755 index 0000000..3bd4dc7 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_video_diffusion/__init__.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + BaseOutput, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure.update( + { + "pipeline_stable_video_diffusion": [ + "StableVideoDiffusionPipeline", + "StableVideoDiffusionPipelineOutput", + ], + } + ) + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_video_diffusion import ( + StableVideoDiffusionPipeline, + StableVideoDiffusionPipelineOutput, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/diffusers/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py new file mode 100755 index 0000000..b14fdd4 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -0,0 +1,726 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import StableVideoDiffusionPipeline + >>> from diffusers.utils import load_image, export_to_video + + >>> pipe = StableVideoDiffusionPipeline.from_pretrained( + ... "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg" + ... ) + >>> image = image.resize((1024, 576)) + + >>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] + >>> export_to_video(frames, "generated.mp4", fps=7) + ``` +""" + + +def _append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class StableVideoDiffusionPipelineOutput(BaseOutput): + r""" + Output class for Stable Video Diffusion pipeline. + + Args: + frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]): + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + num_frames, height, width, num_channels)`. + """ + + frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor] + + +class StableVideoDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline to generate video from an input image using Stable Video Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKLTemporalDecoder`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder + ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). + unet ([`UNetSpatioTemporalConditionModel`]): + A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. + scheduler ([`EulerDiscreteScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images. + """ + + model_cpu_offload_seq = "image_encoder->unet->vae" + _callback_tensor_inputs = ["latents"] + + def __init__( + self, + vae: AutoencoderKLTemporalDecoder, + image_encoder: CLIPVisionModelWithProjection, + unet: UNetSpatioTemporalConditionModel, + scheduler: EulerDiscreteScheduler, + feature_extractor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor) + + def _encode_image( + self, + image: PipelineImageInput, + device: Union[str, torch.device], + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + ) -> torch.Tensor: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.video_processor.pil_to_numpy(image) + image = self.video_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. + image = image * 2.0 - 1.0 + image = _resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def _encode_vae_image( + self, + image: torch.Tensor, + device: Union[str, torch.device], + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + ): + image = image.to(device=device) + image_latents = self.vae.encode(image).latent_dist.mode() + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_latents = torch.cat([negative_image_latents, image_latents]) + + return image_latents + + def _get_add_time_ids( + self, + fps: int, + motion_bucket_id: int, + noise_aug_strength: float, + dtype: torch.dtype, + batch_size: int, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + ): + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + + return add_time_ids + + def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward + accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + def check_inputs(self, image, height, width): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + def prepare_latents( + self, + batch_size: int, + num_frames: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: Union[str, torch.device], + generator: torch.Generator, + latents: Optional[torch.Tensor] = None, + ): + shape = ( + batch_size, + num_frames, + num_channels_latents // 2, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + if isinstance(self.guidance_scale, (int, float)): + return self.guidance_scale > 1 + return self.guidance_scale.max() > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor], + height: int = 576, + width: int = 1024, + num_frames: Optional[int] = None, + num_inference_steps: int = 25, + sigmas: Optional[List[float]] = None, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + fps: int = 7, + motion_bucket_id: int = 127, + noise_aug_strength: float = 0.02, + decode_chunk_size: Optional[int] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + return_dict: bool = True, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`): + Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0, + 1]`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_frames (`int`, *optional*): + The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for + `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`). + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. This parameter is modulated by `strength`. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + min_guidance_scale (`float`, *optional*, defaults to 1.0): + The minimum guidance scale. Used for the classifier free guidance with first frame. + max_guidance_scale (`float`, *optional*, defaults to 3.0): + The maximum guidance scale. Used for the classifier free guidance with last frame. + fps (`int`, *optional*, defaults to 7): + Frames per second. The rate at which the generated images shall be exported to a video after + generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. + motion_bucket_id (`int`, *optional*, defaults to 127): + Used for conditioning the amount of motion for the generation. The higher the number the more motion + will be in the video. + noise_aug_strength (`float`, *optional*, defaults to 0.02): + The amount of noise added to the init image, the higher it is the less the video will look like the + init image. Increase it for more motion. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the + expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality. + For lower memory usage, reduce `decode_chunk_size`. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `pil`, `np` or `pt`. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: + `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. + `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is + returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.Tensor`) is + returned. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_frames = num_frames if num_frames is not None else self.unet.config.num_frames + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + self._guidance_scale = max_guidance_scale + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance) + + # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here. + # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 + fps = fps - 1 + + # 4. Encode input image using VAE + image = self.video_processor.preprocess(image, height=height, width=width).to(device) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) + image = image + noise_aug_strength * noise + + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.vae.to(dtype=torch.float32) + + image_latents = self._encode_vae_image( + image, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + image_latents = image_latents.to(image_embeddings.dtype) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) + + # 5. Get Added Time IDs + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + image_embeddings.dtype, + batch_size, + num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + added_time_ids = added_time_ids.to(device) + + # 6. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas) + + # 7. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_frames, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 8. Prepare guidance scale + guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) + guidance_scale = guidance_scale.to(device, latents.dtype) + guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + + self._guidance_scale = guidance_scale + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Concatenate image_latents over channels dimension + latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=image_embeddings, + added_time_ids=added_time_ids, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + frames = self.decode_latents(latents, num_frames, decode_chunk_size) + frames = self.video_processor.postprocess_video(video=frames, output_type=output_type) + else: + frames = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return frames + + return StableVideoDiffusionPipelineOutput(frames=frames) + + +# resizing utils +# TODO: clean up later +def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): + h, w = input.shape[-2:] + factors = (h / size[0], w / size[1]) + + # First, we have to determine sigma + # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 + sigmas = ( + max((factors[0] - 1.0) / 2.0, 0.001), + max((factors[1] - 1.0) / 2.0, 0.001), + ) + + # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma + # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 + # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now + ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) + + # Make sure it is odd + if (ks[0] % 2) == 0: + ks = ks[0] + 1, ks[1] + + if (ks[1] % 2) == 0: + ks = ks[0], ks[1] + 1 + + input = _gaussian_blur2d(input, ks, sigmas) + + output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) + return output + + +def _compute_padding(kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def _filter2d(input, kernel): + # prepare kernel + b, c, h, w = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape: List[int] = _compute_padding([height, width]) + input = torch.nn.functional.pad(input, padding_shape, mode="reflect") + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + out = output.view(b, c, h, w) + return out + + +def _gaussian(window_size: int, sigma): + if isinstance(sigma, float): + sigma = torch.tensor([[sigma]]) + + batch_size = sigma.shape[0] + + x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) + + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) + + return gauss / gauss.sum(-1, keepdim=True) + + +def _gaussian_blur2d(input, kernel_size, sigma): + if isinstance(sigma, tuple): + sigma = torch.tensor([sigma], dtype=input.dtype) + else: + sigma = sigma.to(dtype=input.dtype) + + ky, kx = int(kernel_size[0]), int(kernel_size[1]) + bs = sigma.shape[0] + kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) + kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) + out_x = _filter2d(input, kernel_x[..., None, :]) + out = _filter2d(out_x, kernel_y[..., None]) + + return out diff --git a/diffusers/src/diffusers/pipelines/t2i_adapter/__init__.py b/diffusers/src/diffusers/pipelines/t2i_adapter/__init__.py new file mode 100755 index 0000000..08c22a2 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/t2i_adapter/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_adapter"] = ["StableDiffusionAdapterPipeline"] + _import_structure["pipeline_stable_diffusion_xl_adapter"] = ["StableDiffusionXLAdapterPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline + from .pipeline_stable_diffusion_xl_adapter import StableDiffusionXLAdapterPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/diffusers/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py new file mode 100755 index 0000000..3cb7c26 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -0,0 +1,942 @@ +# Copyright 2024 TencentARC and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +@dataclass +class StableDiffusionAdapterPipelineOutput(BaseOutput): + """ + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> from diffusers.utils import load_image + >>> import torch + >>> from diffusers import StableDiffusionAdapterPipeline, T2IAdapter + + >>> image = load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png" + ... ) + + >>> color_palette = image.resize((8, 8)) + >>> color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST) + + >>> adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16) + >>> pipe = StableDiffusionAdapterPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", + ... adapter=adapter, + ... torch_dtype=torch.float16, + ... ) + + >>> pipe.to("cuda") + + >>> out_image = pipe( + ... "At night, glowing cubes in front of the beach", + ... image=color_palette, + ... ).images[0] + ``` +""" + + +def _preprocess_adapter_image(image, height, width): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image] + image = [ + i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image + ] # expand [h, w] or [h, w, c] to [b, h, w, c] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + if image[0].ndim == 3: + image = torch.stack(image, dim=0) + elif image[0].ndim == 4: + image = torch.cat(image, dim=0) + else: + raise ValueError( + f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}" + ) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter + https://arxiv.org/abs/2302.08453 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a + list, the outputs from each Adapter are added together to create one combined additional conditioning. + adapter_weights (`List[float]`, *optional*, defaults to None): + List of floats representing the weight which will be multiply to each adapter's output before adding them + together. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->adapter->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(adapter, (list, tuple)): + adapter = MultiAdapter(adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + adapter=adapter, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if isinstance(self.adapter, MultiAdapter): + if not isinstance(image, list): + raise ValueError( + "MultiAdapter is enabled, but `image` is not a list. Please pass a list of images to `image`." + ) + + if len(image) != len(self.adapter.adapters): + raise ValueError( + f"MultiAdapter requires passing the same number of images as adapters. Given {len(image)} images and {len(self.adapter.adapters)} adapters." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[-2] + + # round down to nearest multiple of `self.adapter.downscale_factor` + height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[-1] + + # round down to nearest multiple of `self.adapter.downscale_factor` + width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor + + return height, width + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + adapter_conditioning_scale: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`): + The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the + type is specified as `torch.Tensor`, it is passed to Adapter as is. PIL.Image.Image` can also be + accepted as an image. The control image is automatically resized to fit the output image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the + residual in the original unet. If multiple adapters are specified in init, you can set the + corresponding scale as a list. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + device = self._execution_device + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, image, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + self._guidance_scale = guidance_scale + + if isinstance(self.adapter, MultiAdapter): + adapter_input = [] + + for one_image in image: + one_image = _preprocess_adapter_image(one_image, height, width) + one_image = one_image.to(device=device, dtype=self.adapter.dtype) + adapter_input.append(one_image) + else: + adapter_input = _preprocess_adapter_image(image, height, width) + adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + if isinstance(self.adapter, MultiAdapter): + adapter_state = self.adapter(adapter_input, adapter_conditioning_scale) + for k, v in enumerate(adapter_state): + adapter_state[k] = v + else: + adapter_state = self.adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale + if num_images_per_prompt > 1: + for k, v in enumerate(adapter_state): + adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) + if self.do_classifier_free_guidance: + for k, v in enumerate(adapter_state): + adapter_state[k] = torch.cat([v] * 2, dim=0) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + down_intrablock_additional_residuals=[state.clone() for state in adapter_state], + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionAdapterPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/diffusers/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py new file mode 100755 index 0000000..0ea197e --- /dev/null +++ b/diffusers/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -0,0 +1,1277 @@ +# Copyright 2024 TencentARC and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, DDPMScheduler + >>> from diffusers.utils import load_image + + >>> sketch_image = load_image("https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch.png").convert("L") + + >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0" + + >>> adapter = T2IAdapter.from_pretrained( + ... "Adapter/t2iadapter", + ... subfolder="sketch_sdxl_1.0", + ... torch_dtype=torch.float16, + ... adapter_type="full_adapter_xl", + ... ) + >>> scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler") + + >>> pipe = StableDiffusionXLAdapterPipeline.from_pretrained( + ... model_id, adapter=adapter, torch_dtype=torch.float16, variant="fp16", scheduler=scheduler + ... ).to("cuda") + + >>> generator = torch.manual_seed(42) + >>> sketch_image_out = pipe( + ... prompt="a photo of a dog in real world, high quality", + ... negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality", + ... image=sketch_image, + ... generator=generator, + ... guidance_scale=7.5, + ... ).images[0] + ``` +""" + + +def _preprocess_adapter_image(image, height, width): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image] + image = [ + i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image + ] # expand [h, w] or [h, w, c] to [b, h, w, c] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + if image[0].ndim == 3: + image = torch.stack(image, dim=0) + elif image[0].ndim == 4: + image = torch.cat(image, dim=0) + else: + raise ValueError( + f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}" + ) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLAdapterPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter + https://arxiv.org/abs/2302.08453 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a + list, the outputs from each Adapter are added together to create one combined additional conditioning. + adapter_weights (`List[float]`, *optional*, defaults to None): + List of floats representing the weight which will be multiply to each adapter's output before adding them + together. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]], + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + adapter=adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[-2] + + # round down to nearest multiple of `self.adapter.downscale_factor` + height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[-1] + + # round down to nearest multiple of `self.adapter.downscale_factor` + width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor + + return height, width + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + adapter_conditioning_scale: Union[float, List[float]] = 1.0, + adapter_conditioning_factor: float = 1.0, + clip_skip: Optional[int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`): + The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the + type is specified as `torch.Tensor`, it is passed to Adapter as is. PIL.Image.Image` can also be + accepted as an image. The control image is automatically resized to fit the output image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionAdapterPipelineOutput`] + instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the + residual in the original unet. If multiple adapters are specified in init, you can set the + corresponding scale as a list. + adapter_conditioning_factor (`float`, *optional*, defaults to 1.0): + The fraction of timesteps for which adapter should be applied. If `adapter_conditioning_factor` is + `0.0`, adapter is not applied at all. If `adapter_conditioning_factor` is `1.0`, adapter is applied for + all timesteps. If `adapter_conditioning_factor` is `0.5`, adapter is applied for half of the timesteps. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + + height, width = self._default_height_width(height, width, image) + device = self._execution_device + + if isinstance(self.adapter, MultiAdapter): + adapter_input = [] + + for one_image in image: + one_image = _preprocess_adapter_image(one_image, height, width) + one_image = one_image.to(device=device, dtype=self.adapter.dtype) + adapter_input.append(one_image) + else: + adapter_input = _preprocess_adapter_image(image, height, width) + adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype) + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + ) + + self._guidance_scale = guidance_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3.1 Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare added time ids & embeddings & adapter features + if isinstance(self.adapter, MultiAdapter): + adapter_state = self.adapter(adapter_input, adapter_conditioning_scale) + for k, v in enumerate(adapter_state): + adapter_state[k] = v + else: + adapter_state = self.adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale + if num_images_per_prompt > 1: + for k, v in enumerate(adapter_state): + adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) + if self.do_classifier_free_guidance: + for k, v in enumerate(adapter_state): + adapter_state[k] = torch.cat([v] * 2, dim=0) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + # Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual + if i < int(num_inference_steps * adapter_conditioning_factor): + down_intrablock_additional_residuals = [state.clone() for state in adapter_state] + else: + down_intrablock_additional_residuals = None + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/__init__.py new file mode 100755 index 0000000..8d8fdb9 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["TextToVideoSDPipelineOutput"] + _import_structure["pipeline_text_to_video_synth"] = ["TextToVideoSDPipeline"] + _import_structure["pipeline_text_to_video_synth_img2img"] = ["VideoToVideoSDPipeline"] + _import_structure["pipeline_text_to_video_zero"] = ["TextToVideoZeroPipeline"] + _import_structure["pipeline_text_to_video_zero_sdxl"] = ["TextToVideoZeroSDXLPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_output import TextToVideoSDPipelineOutput + from .pipeline_text_to_video_synth import TextToVideoSDPipeline + from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline + from .pipeline_text_to_video_zero import TextToVideoZeroPipeline + from .pipeline_text_to_video_zero_sdxl import TextToVideoZeroSDXLPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py new file mode 100755 index 0000000..040bf0e --- /dev/null +++ b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL +import torch + +from ...utils import ( + BaseOutput, +) + + +@dataclass +class TextToVideoSDPipelineOutput(BaseOutput): + """ + Output class for text-to-video pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised + PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)` + """ + + frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]] diff --git a/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py new file mode 100755 index 0000000..cdd72b9 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -0,0 +1,643 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet3DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import TextToVideoSDPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import TextToVideoSDPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = TextToVideoSDPipeline.from_pretrained( + ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Spiderman is surfing" + >>> video_frames = pipe(prompt).frames[0] + >>> video_path = export_to_video(video_frames) + >>> video_path + ``` +""" + + +class TextToVideoSDPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet3DConditionModel`]): + A [`UNet3DConditionModel`] to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 9.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `torch.Tensor` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_images_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # reshape latents + bsz, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # reshape latents back + latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 8. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 9. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return TextToVideoSDPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py new file mode 100755 index 0000000..92bf1d3 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -0,0 +1,699 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet3DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import TextToVideoSDPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler + >>> from diffusers.utils import export_to_video + + >>> pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.to("cuda") + + >>> prompt = "spiderman running in the desert" + >>> video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=24).frames[0] + >>> # safe low-res video + >>> video_path = export_to_video(video_frames, output_video_path="./video_576_spiderman.mp4") + + >>> # let's offload the text-to-image model + >>> pipe.to("cpu") + + >>> # and load the image-to-image model + >>> pipe = DiffusionPipeline.from_pretrained( + ... "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16, revision="refs/pr/15" + ... ) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # The VAE consumes A LOT of memory, let's make sure we run it in sliced mode + >>> pipe.vae.enable_slicing() + + >>> # now let's upscale it + >>> video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames] + + >>> # and denoise it + >>> video_frames = pipe(prompt, video=video, strength=0.6).frames[0] + >>> video_path = export_to_video(video_frames, output_video_path="./video_1024_spiderman.mp4") + >>> video_path + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class VideoToVideoSDPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + r""" + Pipeline for text-guided video-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet3DConditionModel`]): + A [`UNet3DConditionModel`] to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, video, timestep, batch_size, dtype, device, generator=None): + video = video.to(device=device, dtype=dtype) + + # change from (b, c, f, h, w) -> (b * f, c, w, h) + bsz, channel, frames, width, height = video.shape + video = video.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + if video.shape[1] == 4: + init_latents = video + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(video), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `video` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + latents = latents[None, :].reshape((bsz, frames, latents.shape[1]) + latents.shape[2:]).permute(0, 2, 1, 3, 4) + + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + video: Union[List[np.ndarray], torch.Tensor] = None, + strength: float = 0.6, + num_inference_steps: int = 50, + guidance_scale: float = 15.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + video (`List[np.ndarray]` or `torch.Tensor`): + `video` frames or tensor representing a video batch to be used as the starting point for the process. + Can also accept video latents as `image`, if passing latents directly, it will not be encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `video`. Must be between 0 and 1. `video` is used as a + starting point, adding more noise to it the larger the `strength`. The number of denoising steps + depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the + denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of + 1 essentially ignores `video`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in video generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `torch.Tensor` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + num_images_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Preprocess video + video = self.video_processor.preprocess_video(video) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # reshape latents + bsz, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # reshape latents back + latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return TextToVideoSDPipelineOutput(frames=video) diff --git a/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py new file mode 100755 index 0000000..c95c7f1 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py @@ -0,0 +1,981 @@ +import copy +import inspect +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from torch.nn.functional import grid_sample +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def rearrange_0(tensor, f): + F, C, H, W = tensor.size() + tensor = torch.permute(torch.reshape(tensor, (F // f, f, C, H, W)), (0, 2, 1, 3, 4)) + return tensor + + +def rearrange_1(tensor): + B, C, F, H, W = tensor.size() + return torch.reshape(torch.permute(tensor, (0, 2, 1, 3, 4)), (B * F, C, H, W)) + + +def rearrange_3(tensor, f): + F, D, C = tensor.size() + return torch.reshape(tensor, (F // f, f, D, C)) + + +def rearrange_4(tensor): + B, F, D, C = tensor.size() + return torch.reshape(tensor, (B * F, D, C)) + + +class CrossFrameAttnProcessor: + """ + Cross frame attention processor. Each frame attends the first frame. + + Args: + batch_size: The number that represents actual batch size, other than the frames. + For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to + 2, due to classifier-free guidance. + """ + + def __init__(self, batch_size=2): + self.batch_size = batch_size + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Cross Frame Attention + if not is_cross_attention: + video_length = key.size()[0] // self.batch_size + first_frame_index = [0] * video_length + + # rearrange keys to have batch and frames in the 1st and 2nd dims respectively + key = rearrange_3(key, video_length) + key = key[:, first_frame_index] + # rearrange values to have batch and frames in the 1st and 2nd dims respectively + value = rearrange_3(value, video_length) + value = value[:, first_frame_index] + + # rearrange back to original shape + key = rearrange_4(key) + value = rearrange_4(value) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class CrossFrameAttnProcessor2_0: + """ + Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0. + + Args: + batch_size: The number that represents actual batch size, other than the frames. + For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to + 2, due to classifier-free guidance. + """ + + def __init__(self, batch_size=2): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.batch_size = batch_size + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Cross Frame Attention + if not is_cross_attention: + video_length = max(1, key.size()[0] // self.batch_size) + first_frame_index = [0] * video_length + + # rearrange keys to have batch and frames in the 1st and 2nd dims respectively + key = rearrange_3(key, video_length) + key = key[:, first_frame_index] + # rearrange values to have batch and frames in the 1st and 2nd dims respectively + value = rearrange_3(value, video_length) + value = value[:, first_frame_index] + + # rearrange back to original shape + key = rearrange_4(key) + value = rearrange_4(value) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +@dataclass +class TextToVideoPipelineOutput(BaseOutput): + r""" + Output class for zero-shot text-to-video pipeline. + + Args: + images (`[List[PIL.Image.Image]`, `np.ndarray`]): + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`[List[bool]]`): + List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or + `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +def coords_grid(batch, ht, wd, device): + # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def warp_single_latent(latent, reference_flow): + """ + Warp latent of a single frame with given flow + + Args: + latent: latent code of a single frame + reference_flow: flow which to warp the latent with + + Returns: + warped: warped latent + """ + _, _, H, W = reference_flow.size() + _, _, h, w = latent.size() + coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype) + + coords_t0 = coords0 + reference_flow + coords_t0[:, 0] /= W + coords_t0[:, 1] /= H + + coords_t0 = coords_t0 * 2.0 - 1.0 + coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear") + coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1)) + + warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection") + return warped + + +def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype): + """ + Create translation motion field + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + device: device + dtype: dtype + + Returns: + + """ + seq_length = len(frame_ids) + reference_flow = torch.zeros((seq_length, 2, 512, 512), device=device, dtype=dtype) + for fr_idx in range(seq_length): + reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (frame_ids[fr_idx]) + reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (frame_ids[fr_idx]) + return reference_flow + + +def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents): + """ + Creates translation motion and warps the latents accordingly + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + latents: latent codes of frames + + Returns: + warped_latents: warped latents + """ + motion_field = create_motion_field( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + frame_ids=frame_ids, + device=latents.device, + dtype=latents.dtype, + ) + warped_latents = latents.clone().detach() + for i in range(len(warped_latents)): + warped_latents[i] = warp_single_latent(latents[i][None], motion_field[i][None]) + return warped_latents + + +class TextToVideoZeroPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + r""" + Pipeline for zero-shot text-to-video generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet3DConditionModel`] to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`CLIPImageProcessor`]): + A [`CLIPImageProcessor`] to extract features from generated images; used as inputs to the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def forward_loop(self, x_t0, t0, t1, generator): + """ + Perform DDPM forward process from time t0 to t1. This is the same as adding noise with corresponding variance. + + Args: + x_t0: + Latent code at time t0. + t0: + Timestep at t0. + t1: + Timestamp at t1. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + + Returns: + x_t1: + Forward process applied to x_t0 from time t0 to t1. + """ + eps = randn_tensor(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device) + alpha_vec = torch.prod(self.scheduler.alphas[t0:t1]) + x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps + return x_t1 + + def backward_loop( + self, + latents, + timesteps, + prompt_embeds, + guidance_scale, + callback, + callback_steps, + num_warmup_steps, + extra_step_kwargs, + cross_attention_kwargs=None, + ): + """ + Perform backward process given list of time steps. + + Args: + latents: + Latents at time timesteps[0]. + timesteps: + Time steps along which to perform backward process. + prompt_embeds: + Pre-generated text embeddings. + guidance_scale: + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + extra_step_kwargs: + Extra_step_kwargs. + cross_attention_kwargs: + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + num_warmup_steps: + number of warmup steps. + + Returns: + latents: + Latents of backward process output at time timesteps[-1]. + """ + do_classifier_free_guidance = guidance_scale > 1.0 + num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order + with self.progress_bar(total=num_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + return latents.clone().detach() + + # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + video_length: Optional[int] = 8, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + motion_field_strength_x: float = 12, + motion_field_strength_y: float = 12, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + t0: int = 44, + t1: int = 47, + frame_ids: Optional[List[int]] = None, + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + video_length (`int`, *optional*, defaults to 8): + The number of generated video frames. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in video generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `"latent"` and `"np"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a + [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput`] instead of + a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + motion_field_strength_x (`float`, *optional*, defaults to 12): + Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439), + Sect. 3.3.1. + motion_field_strength_y (`float`, *optional*, defaults to 12): + Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439), + Sect. 3.3.1. + t0 (`int`, *optional*, defaults to 44): + Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the + [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. + t1 (`int`, *optional*, defaults to 47): + Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the + [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. + frame_ids (`List[int]`, *optional*): + Indexes of the frames that are being generated. This is used when generating longer videos + chunk-by-chunk. + + Returns: + [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput`]: + The output contains a `ndarray` of the generated video, when `output_type` != `"latent"`, otherwise a + latent code of generated videos and a list of `bool`s indicating whether the corresponding generated + video contains "not-safe-for-work" (nsfw) content.. + """ + assert video_length > 0 + if frame_ids is None: + frame_ids = list(range(video_length)) + assert len(frame_ids) == video_length + + assert num_videos_per_prompt == 1 + + # set the processor + original_attn_proc = self.unet.attn_processors + processor = ( + CrossFrameAttnProcessor2_0(batch_size=2) + if hasattr(F, "scaled_dot_product_attention") + else CrossFrameAttnProcessor(batch_size=2) + ) + self.unet.set_attn_processor(processor) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # Encode input prompt + prompt_embeds_tuple = self.encode_prompt( + prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + ) + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # Perform the first backward process up to time T_1 + x_1_t1 = self.backward_loop( + timesteps=timesteps[: -t1 - 1], + prompt_embeds=prompt_embeds, + latents=latents, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=num_warmup_steps, + ) + scheduler_copy = copy.deepcopy(self.scheduler) + + # Perform the second backward process up to time T_0 + x_1_t0 = self.backward_loop( + timesteps=timesteps[-t1 - 1 : -t0 - 1], + prompt_embeds=prompt_embeds, + latents=x_1_t1, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + ) + + # Propagate first frame latents at time T_0 to remaining frames + x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1) + + # Add motion in latents at time T_0 + x_2k_t0 = create_motion_field_and_warp_latents( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + latents=x_2k_t0, + frame_ids=frame_ids[1:], + ) + + # Perform forward process up to time T_1 + x_2k_t1 = self.forward_loop( + x_t0=x_2k_t0, + t0=timesteps[-t0 - 1].item(), + t1=timesteps[-t1 - 1].item(), + generator=generator, + ) + + # Perform backward process from time T_1 to 0 + x_1k_t1 = torch.cat([x_1_t1, x_2k_t1]) + b, l, d = prompt_embeds.size() + prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d) + + self.scheduler = scheduler_copy + x_1k_0 = self.backward_loop( + timesteps=timesteps[-t1 - 1 :], + prompt_embeds=prompt_embeds, + latents=x_1k_t1, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + ) + latents = x_1k_0 + + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + torch.cuda.empty_cache() + + if output_type == "latent": + image = latents + has_nsfw_concept = None + else: + image = self.decode_latents(latents) + # Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + # make sure to set the original attention processors back + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image diff --git a/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py new file mode 100755 index 0000000..c46a5ce --- /dev/null +++ b/diffusers/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py @@ -0,0 +1,1317 @@ +import copy +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from torch.nn.functional import grid_sample +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + FusedAttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + is_invisible_watermark_available, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_0 +def rearrange_0(tensor, f): + F, C, H, W = tensor.size() + tensor = torch.permute(torch.reshape(tensor, (F // f, f, C, H, W)), (0, 2, 1, 3, 4)) + return tensor + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_1 +def rearrange_1(tensor): + B, C, F, H, W = tensor.size() + return torch.reshape(torch.permute(tensor, (0, 2, 1, 3, 4)), (B * F, C, H, W)) + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_3 +def rearrange_3(tensor, f): + F, D, C = tensor.size() + return torch.reshape(tensor, (F // f, f, D, C)) + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_4 +def rearrange_4(tensor): + B, F, D, C = tensor.size() + return torch.reshape(tensor, (B * F, D, C)) + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor +class CrossFrameAttnProcessor: + """ + Cross frame attention processor. Each frame attends the first frame. + + Args: + batch_size: The number that represents actual batch size, other than the frames. + For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to + 2, due to classifier-free guidance. + """ + + def __init__(self, batch_size=2): + self.batch_size = batch_size + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Cross Frame Attention + if not is_cross_attention: + video_length = key.size()[0] // self.batch_size + first_frame_index = [0] * video_length + + # rearrange keys to have batch and frames in the 1st and 2nd dims respectively + key = rearrange_3(key, video_length) + key = key[:, first_frame_index] + # rearrange values to have batch and frames in the 1st and 2nd dims respectively + value = rearrange_3(value, video_length) + value = value[:, first_frame_index] + + # rearrange back to original shape + key = rearrange_4(key) + value = rearrange_4(value) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor2_0 +class CrossFrameAttnProcessor2_0: + """ + Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0. + + Args: + batch_size: The number that represents actual batch size, other than the frames. + For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to + 2, due to classifier-free guidance. + """ + + def __init__(self, batch_size=2): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.batch_size = batch_size + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Cross Frame Attention + if not is_cross_attention: + video_length = max(1, key.size()[0] // self.batch_size) + first_frame_index = [0] * video_length + + # rearrange keys to have batch and frames in the 1st and 2nd dims respectively + key = rearrange_3(key, video_length) + key = key[:, first_frame_index] + # rearrange values to have batch and frames in the 1st and 2nd dims respectively + value = rearrange_3(value, video_length) + value = value[:, first_frame_index] + + # rearrange back to original shape + key = rearrange_4(key) + value = rearrange_4(value) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +@dataclass +class TextToVideoSDXLPipelineOutput(BaseOutput): + """ + Output class for zero-shot text-to-video pipeline. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.coords_grid +def coords_grid(batch, ht, wd, device): + # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.warp_single_latent +def warp_single_latent(latent, reference_flow): + """ + Warp latent of a single frame with given flow + + Args: + latent: latent code of a single frame + reference_flow: flow which to warp the latent with + + Returns: + warped: warped latent + """ + _, _, H, W = reference_flow.size() + _, _, h, w = latent.size() + coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype) + + coords_t0 = coords0 + reference_flow + coords_t0[:, 0] /= W + coords_t0[:, 1] /= H + + coords_t0 = coords_t0 * 2.0 - 1.0 + coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear") + coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1)) + + warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection") + return warped + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.create_motion_field +def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype): + """ + Create translation motion field + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + device: device + dtype: dtype + + Returns: + + """ + seq_length = len(frame_ids) + reference_flow = torch.zeros((seq_length, 2, 512, 512), device=device, dtype=dtype) + for fr_idx in range(seq_length): + reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (frame_ids[fr_idx]) + reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (frame_ids[fr_idx]) + return reference_flow + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.create_motion_field_and_warp_latents +def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents): + """ + Creates translation motion and warps the latents accordingly + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + latents: latent codes of frames + + Returns: + warped_latents: warped latents + """ + motion_field = create_motion_field( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + frame_ids=frame_ids, + device=latents.device, + dtype=latents.dtype, + ) + warped_latents = latents.clone().detach() + for i in range(len(warped_latents)): + warped_latents[i] = warp_single_latent(latents[i][None], motion_field[i][None]) + return warped_latents + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class TextToVideoZeroSDXLPipeline( + DiffusionPipeline, + StableDiffusionMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +): + r""" + Pipeline for zero-shot text-to-video generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoZeroPipeline.forward_loop + def forward_loop(self, x_t0, t0, t1, generator): + """ + Perform DDPM forward process from time t0 to t1. This is the same as adding noise with corresponding variance. + + Args: + x_t0: + Latent code at time t0. + t0: + Timestep at t0. + t1: + Timestamp at t1. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + + Returns: + x_t1: + Forward process applied to x_t0 from time t0 to t1. + """ + eps = randn_tensor(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device) + alpha_vec = torch.prod(self.scheduler.alphas[t0:t1]) + x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps + return x_t1 + + def backward_loop( + self, + latents, + timesteps, + prompt_embeds, + guidance_scale, + callback, + callback_steps, + num_warmup_steps, + extra_step_kwargs, + add_text_embeds, + add_time_ids, + cross_attention_kwargs=None, + guidance_rescale: float = 0.0, + ): + """ + Perform backward process given list of time steps + + Args: + latents: + Latents at time timesteps[0]. + timesteps: + Time steps along which to perform backward process. + prompt_embeds: + Pre-generated text embeddings. + guidance_scale: + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + extra_step_kwargs: + Extra_step_kwargs. + cross_attention_kwargs: + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + num_warmup_steps: + number of warmup steps. + + Returns: + latents: latents of backward process output at time timesteps[-1] + """ + + do_classifier_free_guidance = guidance_scale > 1.0 + num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order + + with self.progress_bar(total=num_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + return latents.clone().detach() + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + prompt_2: Optional[Union[str, List[str]]] = None, + video_length: Optional[int] = 8, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + frame_ids: Optional[List[int]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + latents: Optional[torch.Tensor] = None, + motion_field_strength_x: float = 12, + motion_field_strength_y: float = 12, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + t0: int = 44, + t1: int = 47, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + video_length (`int`, *optional*, defaults to 8): + The number of generated video frames. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + frame_ids (`List[int]`, *optional*): + Indexes of the frames that are being generated. This is used when generating longer videos + chunk-by-chunk. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + motion_field_strength_x (`float`, *optional*, defaults to 12): + Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439), + Sect. 3.3.1. + motion_field_strength_y (`float`, *optional*, defaults to 12): + Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439), + Sect. 3.3.1. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + t0 (`int`, *optional*, defaults to 44): + Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the + [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. + t1 (`int`, *optional*, defaults to 47): + Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the + [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. + + Returns: + [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoSDXLPipelineOutput`] or + `tuple`: [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoSDXLPipelineOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + assert video_length > 0 + if frame_ids is None: + frame_ids = list(range(video_length)) + assert len(frame_ids) == video_length + + assert num_videos_per_prompt == 1 + + # set the processor + original_attn_proc = self.unet.attn_processors + processor = ( + CrossFrameAttnProcessor2_0(batch_size=2) + if hasattr(F, "scaled_dot_product_attention") + else CrossFrameAttnProcessor(batch_size=2) + ) + self.unet.set_attn_processor(processor) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + batch_size = ( + 1 if isinstance(prompt, str) else len(prompt) if isinstance(prompt, list) else prompt_embeds.shape[0] + ) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # Perform the first backward process up to time T_1 + x_1_t1 = self.backward_loop( + timesteps=timesteps[: -t1 - 1], + prompt_embeds=prompt_embeds, + latents=latents, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=num_warmup_steps, + add_text_embeds=add_text_embeds, + add_time_ids=add_time_ids, + ) + + scheduler_copy = copy.deepcopy(self.scheduler) + + # Perform the second backward process up to time T_0 + x_1_t0 = self.backward_loop( + timesteps=timesteps[-t1 - 1 : -t0 - 1], + prompt_embeds=prompt_embeds, + latents=x_1_t1, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + add_text_embeds=add_text_embeds, + add_time_ids=add_time_ids, + ) + + # Propagate first frame latents at time T_0 to remaining frames + x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1) + + # Add motion in latents at time T_0 + x_2k_t0 = create_motion_field_and_warp_latents( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + latents=x_2k_t0, + frame_ids=frame_ids[1:], + ) + + # Perform forward process up to time T_1 + x_2k_t1 = self.forward_loop( + x_t0=x_2k_t0, + t0=timesteps[-t0 - 1].to(torch.long), + t1=timesteps[-t1 - 1].to(torch.long), + generator=generator, + ) + + # Perform backward process from time T_1 to 0 + latents = torch.cat([x_1_t1, x_2k_t1]) + + self.scheduler = scheduler_copy + timesteps = timesteps[-t1 - 1 :] + + b, l, d = prompt_embeds.size() + prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d) + + b, k = add_text_embeds.size() + add_text_embeds = add_text_embeds[:, None].repeat(1, video_length, 1).reshape(b * video_length, k) + + b, k = add_time_ids.size() + add_time_ids = add_time_ids[:, None].repeat(1, video_length, 1).reshape(b * video_length, k) + + # 7.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + x_1k_0 = self.backward_loop( + timesteps=timesteps, + prompt_embeds=prompt_embeds, + latents=latents, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + add_text_embeds=add_text_embeds, + add_time_ids=add_time_ids, + ) + + latents = x_1k_0 + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return TextToVideoSDXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + # make sure to set the original attention processors back + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return TextToVideoSDXLPipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/unclip/__init__.py b/diffusers/src/diffusers/pipelines/unclip/__init__.py new file mode 100755 index 0000000..c89e899 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/unclip/__init__.py @@ -0,0 +1,52 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import UnCLIPImageVariationPipeline, UnCLIPPipeline + + _dummy_objects.update( + {"UnCLIPImageVariationPipeline": UnCLIPImageVariationPipeline, "UnCLIPPipeline": UnCLIPPipeline} + ) +else: + _import_structure["pipeline_unclip"] = ["UnCLIPPipeline"] + _import_structure["pipeline_unclip_image_variation"] = ["UnCLIPImageVariationPipeline"] + _import_structure["text_proj"] = ["UnCLIPTextProjModel"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_unclip import UnCLIPPipeline + from .pipeline_unclip_image_variation import UnCLIPImageVariationPipeline + from .text_proj import UnCLIPTextProjModel + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/unclip/pipeline_unclip.py b/diffusers/src/diffusers/pipelines/unclip/pipeline_unclip.py new file mode 100755 index 0000000..25c6739 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -0,0 +1,493 @@ +# Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch +from torch.nn import functional as F +from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + +from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel +from ...schedulers import UnCLIPScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .text_proj import UnCLIPTextProjModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`~transformers.CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution UNet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution UNet. Used in the last step of the super resolution diffusion process. + prior_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the prior denoising process (a modified [`DDPMScheduler`]). + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]). + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]). + + """ + + _exclude_from_cpu_offload = ["prior"] + + prior: PriorTransformer + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + prior_scheduler: UnCLIPScheduler + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + model_cpu_offload_seq = "text_encoder->text_proj->decoder->super_res_first->super_res_last" + + def __init__( + self, + prior: PriorTransformer, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + prior_scheduler: UnCLIPScheduler, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + prior=prior, + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_enc_hid_states = text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_enc_hid_states.shape[1] + uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1) + uncond_text_enc_hid_states = uncond_text_enc_hid_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_enc_hid_states, text_mask + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + prior_num_inference_steps: int = 25, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prior_latents: Optional[torch.Tensor] = None, + decoder_latents: Optional[torch.Tensor] = None, + super_res_latents: Optional[torch.Tensor] = None, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + prior_guidance_scale: float = 4.0, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. This can only be left undefined if `text_model_output` + and `text_attention_mask` is passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the prior. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prior_latents (`torch.Tensor` of shape (batch size, embeddings dimension), *optional*): + Pre-generated noisy latents to be used as inputs for the prior. + decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + text_model_output (`CLIPTextModelOutput`, *optional*): + Pre-defined [`CLIPTextModel`] outputs that can be derived from the text encoder. Pre-defined text + outputs can be passed for tasks like text embedding interpolations. Make sure to also pass + `text_attention_mask` in this case. `prompt` can the be left `None`. + text_attention_mask (`torch.Tensor`, *optional*): + Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention + masks are necessary when passing `text_model_output`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + else: + batch_size = text_model_output[0].shape[0] + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 + + prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask + ) + + # prior + + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_enc_hid_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeddings = prior_latents + + # done prior + + # decoder + + text_enc_hid_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_enc_hid_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + if device.type == "mps": + # HACK: MPS: There is a panic when padding bool tensors, + # so cast to int tensor for the pad and back to bool afterwards + text_mask = text_mask.type(torch.int) + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + decoder_text_mask = decoder_text_mask.type(torch.bool) + else: + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size + + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_enc_hid_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_enc_hid_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.config.in_channels // 2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size + + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + if device.type == "mps": + # MPS does not support many interpolations + image_upscaled = F.interpolate(image_small, size=[height, width]) + else: + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = super_res_latents + # done super res + + self.maybe_free_model_hooks() + + # post processing + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/diffusers/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py new file mode 100755 index 0000000..2a0e7e9 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -0,0 +1,420 @@ +# Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Union + +import PIL.Image +import torch +from torch.nn import functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...models import UNet2DConditionModel, UNet2DModel +from ...schedulers import UnCLIPScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .text_proj import UnCLIPTextProjModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPImageVariationPipeline(DiffusionPipeline): + """ + Pipeline to generate image variations from an input image using UnCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`~transformers.CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution UNet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution UNet. Used in the last step of the super resolution diffusion process. + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]). + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]). + """ + + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + feature_extractor: CLIPImageProcessor + image_encoder: CLIPVisionModelWithProjection + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + model_cpu_offload_seq = "text_encoder->image_encoder->text_proj->decoder->super_res_first->super_res_last" + + def __init__( + self, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + super_res_first=super_res_first, + super_res_last=super_res_last, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None): + dtype = next(self.image_encoder.parameters()).dtype + + if image_embeddings is None: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + + image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + return image_embeddings + + @torch.no_grad() + def __call__( + self, + image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor]] = None, + num_images_per_prompt: int = 1, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[torch.Generator] = None, + decoder_latents: Optional[torch.Tensor] = None, + super_res_latents: Optional[torch.Tensor] = None, + image_embeddings: Optional[torch.Tensor] = None, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`): + `Image` or tensor representing an image batch to be used as the starting point. If you provide a + tensor, it needs to be compatible with the [`CLIPImageProcessor`] + [configuration](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + Can be left as `None` only when `image_embeddings` are passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + image_embeddings (`torch.Tensor`, *optional*): + Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings + can be passed for tasks like image interpolations. `image` can be left as `None`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + if image is not None: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + else: + batch_size = image_embeddings.shape[0] + + prompt = [""] * batch_size + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = decoder_guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance + ) + + image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings) + + # decoder + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + if device.type == "mps": + # HACK: MPS: There is a panic when padding bool tensors, + # so cast to int tensor for the pad and back to bool afterwards + text_mask = text_mask.type(torch.int) + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + decoder_text_mask = decoder_text_mask.type(torch.bool) + else: + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size + + if decoder_latents is None: + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.config.in_channels // 2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size + + if super_res_latents is None: + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + if device.type == "mps": + # MPS does not support many interpolations + image_upscaled = F.interpolate(image_small, size=[height, width]) + else: + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = super_res_latents + + # done super res + self.maybe_free_model_hooks() + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/src/diffusers/pipelines/unclip/text_proj.py b/diffusers/src/diffusers/pipelines/unclip/text_proj.py new file mode 100755 index 0000000..5a86d0c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/unclip/text_proj.py @@ -0,0 +1,86 @@ +# Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +class UnCLIPTextProjModel(ModelMixin, ConfigMixin): + """ + Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the + decoder. + + For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1 + """ + + @register_to_config + def __init__( + self, + *, + clip_extra_context_tokens: int = 4, + clip_embeddings_dim: int = 768, + time_embed_dim: int, + cross_attention_dim, + ): + super().__init__() + + self.learned_classifier_free_guidance_embeddings = nn.Parameter(torch.zeros(clip_embeddings_dim)) + + # parameters for additional clip time embeddings + self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim) + self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clip_embeddings_dim, time_embed_dim) + + # parameters for encoder hidden states + self.clip_extra_context_tokens = clip_extra_context_tokens + self.clip_extra_context_tokens_proj = nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim + ) + self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim) + self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance): + if do_classifier_free_guidance: + # Add the classifier free guidance embeddings to the image embeddings + image_embeddings_batch_size = image_embeddings.shape[0] + classifier_free_guidance_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0) + classifier_free_guidance_embeddings = classifier_free_guidance_embeddings.expand( + image_embeddings_batch_size, -1 + ) + image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0) + + # The image embeddings batch size and the text embeddings batch size are equal + assert image_embeddings.shape[0] == prompt_embeds.shape[0] + + batch_size = prompt_embeds.shape[0] + + # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and + # adding CLIP embeddings to the existing timestep embedding, ... + time_projected_prompt_embeds = self.embedding_proj(prompt_embeds) + time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings) + additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_prompt_embeds + + # ... and by projecting CLIP embeddings into four + # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder" + clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings) + clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens) + clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1) + + text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states) + text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states) + text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1) + + return text_encoder_hidden_states, additive_clip_time_embeddings diff --git a/diffusers/src/diffusers/pipelines/unidiffuser/__init__.py b/diffusers/src/diffusers/pipelines/unidiffuser/__init__.py new file mode 100755 index 0000000..1ac2b09 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/unidiffuser/__init__.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + ImageTextPipelineOutput, + UniDiffuserPipeline, + ) + + _dummy_objects.update( + {"ImageTextPipelineOutput": ImageTextPipelineOutput, "UniDiffuserPipeline": UniDiffuserPipeline} + ) +else: + _import_structure["modeling_text_decoder"] = ["UniDiffuserTextDecoder"] + _import_structure["modeling_uvit"] = ["UniDiffuserModel", "UTransformer2DModel"] + _import_structure["pipeline_unidiffuser"] = ["ImageTextPipelineOutput", "UniDiffuserPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + ImageTextPipelineOutput, + UniDiffuserPipeline, + ) + else: + from .modeling_text_decoder import UniDiffuserTextDecoder + from .modeling_uvit import UniDiffuserModel, UTransformer2DModel + from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/diffusers/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py new file mode 100755 index 0000000..75e5d43 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py @@ -0,0 +1,296 @@ +from typing import Optional + +import numpy as np +import torch +from torch import nn +from transformers import GPT2Config, GPT2LMHeadModel +from transformers.modeling_utils import ModuleUtilsMixin + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py +class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + """ + Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to + generate text from the UniDiffuser image-text embedding. + + Parameters: + prefix_length (`int`): + Max number of prefix tokens that will be supplied to the model. + prefix_inner_dim (`int`): + The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the + CLIP text encoder. + prefix_hidden_dim (`int`, *optional*): + Hidden dim of the MLP if we encode the prefix. + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + """ + + _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"] + + @register_to_config + def __init__( + self, + prefix_length: int, + prefix_inner_dim: int, + prefix_hidden_dim: Optional[int] = None, + vocab_size: int = 50257, # Start of GPT2 config args + n_positions: int = 1024, + n_embd: int = 768, + n_layer: int = 12, + n_head: int = 12, + n_inner: Optional[int] = None, + activation_function: str = "gelu_new", + resid_pdrop: float = 0.1, + embd_pdrop: float = 0.1, + attn_pdrop: float = 0.1, + layer_norm_epsilon: float = 1e-5, + initializer_range: float = 0.02, + scale_attn_weights: bool = True, + use_cache: bool = True, + scale_attn_by_inverse_layer_idx: bool = False, + reorder_and_upcast_attn: bool = False, + ): + super().__init__() + + self.prefix_length = prefix_length + + if prefix_inner_dim != n_embd and prefix_hidden_dim is None: + raise ValueError( + f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_hidden_dim} and" + f" `n_embd`: {n_embd} are not equal." + ) + + self.prefix_inner_dim = prefix_inner_dim + self.prefix_hidden_dim = prefix_hidden_dim + + self.encode_prefix = ( + nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim) + if self.prefix_hidden_dim is not None + else nn.Identity() + ) + self.decode_prefix = ( + nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity() + ) + + gpt_config = GPT2Config( + vocab_size=vocab_size, + n_positions=n_positions, + n_embd=n_embd, + n_layer=n_layer, + n_head=n_head, + n_inner=n_inner, + activation_function=activation_function, + resid_pdrop=resid_pdrop, + embd_pdrop=embd_pdrop, + attn_pdrop=attn_pdrop, + layer_norm_epsilon=layer_norm_epsilon, + initializer_range=initializer_range, + scale_attn_weights=scale_attn_weights, + use_cache=use_cache, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + reorder_and_upcast_attn=reorder_and_upcast_attn, + ) + self.transformer = GPT2LMHeadModel(gpt_config) + + def forward( + self, + input_ids: torch.Tensor, + prefix_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ): + """ + Args: + input_ids (`torch.Tensor` of shape `(N, max_seq_len)`): + Text tokens to use for inference. + prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, 768)`): + Prefix embedding to preprend to the embedded tokens. + attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len, 768)`, *optional*): + Attention mask for the prefix embedding. + labels (`torch.Tensor`, *optional*): + Labels to use for language modeling. + """ + embedding_text = self.transformer.transformer.wte(input_ids) + hidden = self.encode_prefix(prefix_embeds) + prefix_embeds = self.decode_prefix(hidden) + embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1) + + if labels is not None: + dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device) + labels = torch.cat((dummy_token, input_ids), dim=1) + out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask) + if self.prefix_hidden_dim is not None: + return out, hidden + else: + return out + + def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor: + return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device) + + def encode(self, prefix): + return self.encode_prefix(prefix) + + @torch.no_grad() + def generate_captions(self, features, eos_token_id, device): + """ + Generate captions given text embedding features. Returns list[L]. + + Args: + features (`torch.Tensor` of shape `(B, L, D)`): + Text embedding features to generate captions from. + eos_token_id (`int`): + The token ID of the EOS token for the text decoder model. + device: + Device to perform text generation on. + + Returns: + `List[str]`: A list of strings generated from the decoder model. + """ + + features = torch.split(features, 1, dim=0) + generated_tokens = [] + generated_seq_lengths = [] + for feature in features: + feature = self.decode_prefix(feature.to(device)) # back to the clip feature + # Only support beam search for now + output_tokens, seq_lengths = self.generate_beam( + input_embeds=feature, device=device, eos_token_id=eos_token_id + ) + generated_tokens.append(output_tokens[0]) + generated_seq_lengths.append(seq_lengths[0]) + generated_tokens = torch.stack(generated_tokens) + generated_seq_lengths = torch.stack(generated_seq_lengths) + return generated_tokens, generated_seq_lengths + + @torch.no_grad() + def generate_beam( + self, + input_ids=None, + input_embeds=None, + device=None, + beam_size: int = 5, + entry_length: int = 67, + temperature: float = 1.0, + eos_token_id: Optional[int] = None, + ): + """ + Generates text using the given tokenizer and text prompt or token embedding via beam search. This + implementation is based on the beam search implementation from the [original UniDiffuser + code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89). + + Args: + eos_token_id (`int`, *optional*): + The token ID of the EOS token for the text decoder model. + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds` + must be supplied. + input_embeds (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + An embedded representation to directly pass to the transformer as a prefix for beam search. One of + `input_ids` and `input_embeds` must be supplied. + device: + The device to perform beam search on. + beam_size (`int`, *optional*, defaults to `5`): + The number of best states to store during beam search. + entry_length (`int`, *optional*, defaults to `67`): + The number of iterations to run beam search. + temperature (`float`, *optional*, defaults to 1.0): + The temperature to use when performing the softmax over logits from the decoding model. + + Returns: + `Tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated + token sequences sorted by score in descending order, and the second element is the sequence lengths + corresponding to those sequences. + """ + # Generates text until stop_token is reached using beam search with the desired beam size. + stop_token_index = eos_token_id + tokens = None + scores = None + seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int) + is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool) + + if input_embeds is not None: + generated = input_embeds + else: + generated = self.transformer.transformer.wte(input_ids) + + for i in range(entry_length): + outputs = self.transformer(inputs_embeds=generated) + logits = outputs.logits + logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) + logits = logits.softmax(-1).log() + + if scores is None: + scores, next_tokens = logits.topk(beam_size, -1) + generated = generated.expand(beam_size, *generated.shape[1:]) + next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) + if tokens is None: + tokens = next_tokens + else: + tokens = tokens.expand(beam_size, *tokens.shape[1:]) + tokens = torch.cat((tokens, next_tokens), dim=1) + else: + logits[is_stopped] = -float(np.inf) + logits[is_stopped, 0] = 0 + scores_sum = scores[:, None] + logits + seq_lengths[~is_stopped] += 1 + scores_sum_average = scores_sum / seq_lengths[:, None] + scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1) + next_tokens_source = next_tokens // scores_sum.shape[1] + seq_lengths = seq_lengths[next_tokens_source] + next_tokens = next_tokens % scores_sum.shape[1] + next_tokens = next_tokens.unsqueeze(1) + tokens = tokens[next_tokens_source] + tokens = torch.cat((tokens, next_tokens), dim=1) + generated = generated[next_tokens_source] + scores = scores_sum_average * seq_lengths + is_stopped = is_stopped[next_tokens_source] + + next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1) + generated = torch.cat((generated, next_token_embed), dim=1) + is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze() + if is_stopped.all(): + break + + scores = scores / seq_lengths + order = scores.argsort(descending=True) + # tokens tensors are already padded to max_seq_length + output_texts = [tokens[i] for i in order] + output_texts = torch.stack(output_texts, dim=0) + seq_lengths = torch.tensor([seq_lengths[i] for i in order], dtype=seq_lengths.dtype) + return output_texts, seq_lengths diff --git a/diffusers/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/diffusers/src/diffusers/pipelines/unidiffuser/modeling_uvit.py new file mode 100755 index 0000000..cb1514b --- /dev/null +++ b/diffusers/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -0,0 +1,1197 @@ +import math +from typing import Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin +from ...models.attention import FeedForward +from ...models.attention_processor import Attention +from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed +from ...models.modeling_outputs import Transformer2DModelOutput +from ...models.normalization import AdaLayerNorm +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + logger.warning( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect." + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, + \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for + generating the random values works best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + use_pos_embed=True, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.use_pos_embed = use_pos_embed + if self.use_pos_embed: + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + if self.use_pos_embed: + return latent + self.pos_embed + else: + return latent + + +class SkipBlock(nn.Module): + def __init__(self, dim: int): + super().__init__() + + self.skip_linear = nn.Linear(2 * dim, dim) + + # Use torch.nn.LayerNorm for now, following the original code + self.norm = nn.LayerNorm(dim) + + def forward(self, x, skip): + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = self.norm(x) + + return x + + +# Modified to support both pre-LayerNorm and post-LayerNorm configurations +# Don't support AdaLayerNormZero for now +# Modified from diffusers.models.attention.BasicTransformerBlock +class UTransformerBlock(nn.Module): + r""" + A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (:obj: `int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (:obj: `bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float32 when performing the attention calculation. + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The layer norm implementation to use. + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, e.g. + `pre_layer_norm = True`. + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + pre_layer_norm: bool = True, + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + self.pre_layer_norm = pre_layer_norm + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Pre-LayerNorm + if self.pre_layer_norm: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + else: + norm_hidden_states = self.norm1(hidden_states) + else: + norm_hidden_states = hidden_states + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + # Post-LayerNorm + if not self.pre_layer_norm: + if self.use_ada_layer_norm: + attn_output = self.norm1(attn_output, timestep) + else: + attn_output = self.norm1(attn_output) + + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + else: + norm_hidden_states = hidden_states + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + # Post-LayerNorm + if not self.pre_layer_norm: + attn_output = self.norm2(attn_output, timestep) if self.use_ada_layer_norm else self.norm2(attn_output) + + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = self.norm3(hidden_states) + else: + norm_hidden_states = hidden_states + + ff_output = self.ff(norm_hidden_states) + + # Post-LayerNorm + if not self.pre_layer_norm: + ff_output = self.norm3(ff_output) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +# Like UTransformerBlock except with LayerNorms on the residual backbone of the block +# Modified from diffusers.models.attention.BasicTransformerBlock +class UniDiffuserBlock(nn.Module): + r""" + A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations and puts the + LayerNorms on the residual backbone of the block. This matches the transformer block in the [original UniDiffuser + implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104). + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (:obj: `int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (:obj: `bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float() when performing the attention calculation. + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The layer norm implementation to use. + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + pre_layer_norm: bool = False, + final_dropout: bool = True, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + self.pre_layer_norm = pre_layer_norm + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Following the diffusers transformer block implementation, put the LayerNorm on the + # residual backbone + # Pre-LayerNorm + if self.pre_layer_norm: + if self.use_ada_layer_norm: + hidden_states = self.norm1(hidden_states, timestep) + else: + hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # Following the diffusers transformer block implementation, put the LayerNorm on the + # residual backbone + # Post-LayerNorm + if not self.pre_layer_norm: + if self.use_ada_layer_norm: + hidden_states = self.norm1(hidden_states, timestep) + else: + hidden_states = self.norm1(hidden_states) + + if self.attn2 is not None: + # Pre-LayerNorm + if self.pre_layer_norm: + hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + # 2. Cross-Attention + attn_output = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # Post-LayerNorm + if not self.pre_layer_norm: + hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 3. Feed-forward + # Pre-LayerNorm + if self.pre_layer_norm: + hidden_states = self.norm3(hidden_states) + + ff_output = self.ff(hidden_states) + + hidden_states = ff_output + hidden_states + + # Post-LayerNorm + if not self.pre_layer_norm: + hidden_states = self.norm3(hidden_states) + + return hidden_states + + +# Modified from diffusers.models.transformer_2d.Transformer2DModel +# Modify the transformer block structure to be U-Net like following U-ViT +# Only supports patch-style input and torch.nn.LayerNorm currently +# https://github.com/baofff/U-ViT +class UTransformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared + to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion, + similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`] + layer and then reshaped to (b, t, d). + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input. + out_channels (`int`, *optional*): + The number of output channels; if `None`, defaults to `in_channels`. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups to use when performing Group Normalization. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + patch_size (`int`, *optional*, defaults to 2): + The patch size to use in the patch embedding. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + use_linear_projection (int, *optional*): TODO: Not used + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used in each + transformer block. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float() when performing the attention calculation. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. + block_type (`str`, *optional*, defaults to `"unidiffuser"`): + The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual + backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard + behavior in `diffusers`.) + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + use_patch_pos_embed (`bool`, *optional*): + Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = 2, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + block_type: str = "unidiffuser", + pre_layer_norm: bool = False, + norm_elementwise_affine: bool = True, + use_patch_pos_embed=False, + ff_final_dropout: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Input + # Only support patch input of shape (batch_size, num_channels, height, width) for now + assert in_channels is not None and patch_size is not None, "Patch input requires in_channels and patch_size." + + assert sample_size is not None, "UTransformer2DModel over patched input must provide sample_size" + + # 2. Define input layers + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + use_pos_embed=use_patch_pos_embed, + ) + + # 3. Define transformers blocks + # Modify this to have in_blocks ("downsample" blocks, even though we don't actually downsample), a mid_block, + # and out_blocks ("upsample" blocks). Like a U-Net, there are skip connections from in_blocks to out_blocks in + # a "U"-shaped fashion (e.g. first in_block to last out_block, etc.). + # Quick hack to make the transformer block type configurable + if block_type == "unidiffuser": + block_cls = UniDiffuserBlock + else: + block_cls = UTransformerBlock + self.transformer_in_blocks = nn.ModuleList( + [ + block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ) + for d in range(num_layers // 2) + ] + ) + + self.transformer_mid_block = block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ) + + # For each skip connection, we use a SkipBlock (concatenation + Linear + LayerNorm) to process the inputs + # before each transformer out_block. + self.transformer_out_blocks = nn.ModuleList( + [ + nn.ModuleDict( + { + "skip": SkipBlock( + inner_dim, + ), + "block": block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ), + } + ) + for d in range(num_layers // 2) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + + # Following the UniDiffuser U-ViT implementation, we process the transformer output with + # a LayerNorm layer with per-element affine params + self.norm_out = nn.LayerNorm(inner_dim) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: bool = True, + hidden_states_is_embedding: bool = False, + unpatchify: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + cross_attention_kwargs (*optional*): + Keyword arguments to supply to the cross attention layers, if used. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + hidden_states_is_embedding (`bool`, *optional*, defaults to `False`): + Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will + ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the + transformer blocks. + unpatchify (`bool`, *optional*, defaults to `True`): + Whether to unpatchify the transformer output. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 0. Check inputs + + if not unpatchify and return_dict: + raise ValueError( + f"Cannot both define `unpatchify`: {unpatchify} and `return_dict`: {return_dict} since when" + f" `unpatchify` is {unpatchify} the returned output is of shape (batch_size, seq_len, hidden_dim)" + " rather than (batch_size, num_channels, height, width)." + ) + + # 1. Input + if not hidden_states_is_embedding: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + + # In ("downsample") blocks + skips = [] + for in_block in self.transformer_in_blocks: + hidden_states = in_block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + skips.append(hidden_states) + + # Mid block + hidden_states = self.transformer_mid_block(hidden_states) + + # Out ("upsample") blocks + for out_block in self.transformer_out_blocks: + hidden_states = out_block["skip"](hidden_states, skips.pop()) + hidden_states = out_block["block"]( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + # Don't support AdaLayerNorm for now, so no conditioning/scale/shift logic + hidden_states = self.norm_out(hidden_states) + # hidden_states = self.proj_out(hidden_states) + + if unpatchify: + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + else: + output = hidden_states + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class UniDiffuserModel(ModelMixin, ConfigMixin): + """ + Transformer model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is a + modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the + CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details). + + Parameters: + text_dim (`int`): The hidden dimension of the CLIP text model used to embed images. + clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed prompts. + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input. + out_channels (`int`, *optional*): + The number of output channels; if `None`, defaults to `in_channels`. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups to use when performing Group Normalization. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + patch_size (`int`, *optional*, defaults to 2): + The patch size to use in the patch embedding. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + use_linear_projection (int, *optional*): TODO: Not used + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used in each + transformer block. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float32 when performing the attention calculation. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. + block_type (`str`, *optional*, defaults to `"unidiffuser"`): + The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual + backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard + behavior in `diffusers`.) + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + use_patch_pos_embed (`bool`, *optional*): + Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). + ff_final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + use_data_type_embedding (`bool`, *optional*): + Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1 + is continue-trained from UniDiffuser-v0 on non-publically-available data and accepts a `data_type` + argument, which can either be `1` to use the weights trained on non-publically-available data or `0` + otherwise. This argument is subsequently embedded by the data type embedding, if used. + """ + + @register_to_config + def __init__( + self, + text_dim: int = 768, + clip_img_dim: int = 512, + num_text_tokens: int = 77, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + block_type: str = "unidiffuser", + pre_layer_norm: bool = False, + use_timestep_embedding=False, + norm_elementwise_affine: bool = True, + use_patch_pos_embed=False, + ff_final_dropout: bool = True, + use_data_type_embedding: bool = False, + ): + super().__init__() + + # 0. Handle dimensions + self.inner_dim = num_attention_heads * attention_head_dim + + assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size" + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.patch_size = patch_size + # Assume image is square... + self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size) + + # 1. Define input layers + # 1.1 Input layers for text and image input + # For now, only support patch input for VAE latent image input + self.vae_img_in = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + use_pos_embed=use_patch_pos_embed, + ) + self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim) + self.text_in = nn.Linear(text_dim, self.inner_dim) + + # 1.2. Timestep embeddings for t_img, t_text + self.timestep_img_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.timestep_img_embed = ( + TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + if use_timestep_embedding + else nn.Identity() + ) + + self.timestep_text_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.timestep_text_embed = ( + TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + if use_timestep_embedding + else nn.Identity() + ) + + # 1.3. Positional embedding + self.num_text_tokens = num_text_tokens + self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim)) + self.pos_embed_drop = nn.Dropout(p=dropout) + trunc_normal_(self.pos_embed, std=0.02) + + # 1.4. Handle data type token embeddings for UniDiffuser-V1, if necessary + self.use_data_type_embedding = use_data_type_embedding + if self.use_data_type_embedding: + self.data_type_token_embedding = nn.Embedding(2, self.inner_dim) + self.data_type_pos_embed_token = nn.Parameter(torch.zeros(1, 1, self.inner_dim)) + + # 2. Define transformer blocks + self.transformer = UTransformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + out_channels=out_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + patch_size=patch_size, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + block_type=block_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + use_patch_pos_embed=use_patch_pos_embed, + ff_final_dropout=ff_final_dropout, + ) + + # 3. Define output layers + patch_dim = (patch_size**2) * out_channels + self.vae_img_out = nn.Linear(self.inner_dim, patch_dim) + self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim) + self.text_out = nn.Linear(self.inner_dim, text_dim) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed"} + + def forward( + self, + latent_image_embeds: torch.Tensor, + image_embeds: torch.Tensor, + prompt_embeds: torch.Tensor, + timestep_img: Union[torch.Tensor, float, int], + timestep_text: Union[torch.Tensor, float, int], + data_type: Optional[Union[torch.Tensor, float, int]] = 1, + encoder_hidden_states=None, + cross_attention_kwargs=None, + ): + """ + Args: + latent_image_embeds (`torch.Tensor` of shape `(batch size, latent channels, height, width)`): + Latent image representation from the VAE encoder. + image_embeds (`torch.Tensor` of shape `(batch size, 1, clip_img_dim)`): + CLIP-embedded image representation (unsqueezed in the first dimension). + prompt_embeds (`torch.Tensor` of shape `(batch size, seq_len, text_dim)`): + CLIP-embedded text representation. + timestep_img (`torch.long` or `float` or `int`): + Current denoising step for the image. + timestep_text (`torch.long` or `float` or `int`): + Current denoising step for the text. + data_type: (`torch.int` or `float` or `int`, *optional*, defaults to `1`): + Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data, + or `0` otherwise. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + cross_attention_kwargs (*optional*): + Keyword arguments to supply to the cross attention layers, if used. + + + Returns: + `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is tbe VAE + image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text + embedding. + """ + batch_size = latent_image_embeds.shape[0] + + # 1. Input + # 1.1. Map inputs to shape (B, N, inner_dim) + vae_hidden_states = self.vae_img_in(latent_image_embeds) + clip_hidden_states = self.clip_img_in(image_embeds) + text_hidden_states = self.text_in(prompt_embeds) + + num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1) + + # 1.2. Encode image timesteps to single token (B, 1, inner_dim) + if not torch.is_tensor(timestep_img): + timestep_img = torch.tensor([timestep_img], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_img = timestep_img * torch.ones(batch_size, dtype=timestep_img.dtype, device=timestep_img.device) + + timestep_img_token = self.timestep_img_proj(timestep_img) + # t_img_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timestep_img_token = timestep_img_token.to(dtype=self.dtype) + timestep_img_token = self.timestep_img_embed(timestep_img_token) + timestep_img_token = timestep_img_token.unsqueeze(dim=1) + + # 1.3. Encode text timesteps to single token (B, 1, inner_dim) + if not torch.is_tensor(timestep_text): + timestep_text = torch.tensor([timestep_text], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_text = timestep_text * torch.ones(batch_size, dtype=timestep_text.dtype, device=timestep_text.device) + + timestep_text_token = self.timestep_text_proj(timestep_text) + # t_text_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timestep_text_token = timestep_text_token.to(dtype=self.dtype) + timestep_text_token = self.timestep_text_embed(timestep_text_token) + timestep_text_token = timestep_text_token.unsqueeze(dim=1) + + # 1.4. Concatenate all of the embeddings together. + if self.use_data_type_embedding: + assert data_type is not None, "data_type must be supplied if the model uses a data type embedding" + if not torch.is_tensor(data_type): + data_type = torch.tensor([data_type], dtype=torch.int, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + data_type = data_type * torch.ones(batch_size, dtype=data_type.dtype, device=data_type.device) + + data_type_token = self.data_type_token_embedding(data_type).unsqueeze(dim=1) + hidden_states = torch.cat( + [ + timestep_img_token, + timestep_text_token, + data_type_token, + text_hidden_states, + clip_hidden_states, + vae_hidden_states, + ], + dim=1, + ) + else: + hidden_states = torch.cat( + [timestep_img_token, timestep_text_token, text_hidden_states, clip_hidden_states, vae_hidden_states], + dim=1, + ) + + # 1.5. Prepare the positional embeddings and add to hidden states + # Note: I think img_vae should always have the proper shape, so there's no need to interpolate + # the position embeddings. + if self.use_data_type_embedding: + pos_embed = torch.cat( + [self.pos_embed[:, : 1 + 1, :], self.data_type_pos_embed_token, self.pos_embed[:, 1 + 1 :, :]], dim=1 + ) + else: + pos_embed = self.pos_embed + hidden_states = hidden_states + pos_embed + hidden_states = self.pos_embed_drop(hidden_states) + + # 2. Blocks + hidden_states = self.transformer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=None, + class_labels=None, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + hidden_states_is_embedding=True, + unpatchify=False, + )[0] + + # 3. Output + # Split out the predicted noise representation. + if self.use_data_type_embedding: + ( + t_img_token_out, + t_text_token_out, + data_type_token_out, + text_out, + img_clip_out, + img_vae_out, + ) = hidden_states.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1) + else: + t_img_token_out, t_text_token_out, text_out, img_clip_out, img_vae_out = hidden_states.split( + (1, 1, num_text_tokens, 1, num_img_tokens), dim=1 + ) + + img_vae_out = self.vae_img_out(img_vae_out) + + # unpatchify + height = width = int(img_vae_out.shape[1] ** 0.5) + img_vae_out = img_vae_out.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + img_vae_out = torch.einsum("nhwpqc->nchpwq", img_vae_out) + img_vae_out = img_vae_out.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + img_clip_out = self.clip_img_out(img_clip_out) + + text_out = self.text_out(text_out) + + return img_vae_out, img_clip_out, text_out diff --git a/diffusers/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/diffusers/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py new file mode 100755 index 0000000..4f65caf --- /dev/null +++ b/diffusers/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py @@ -0,0 +1,1420 @@ +import inspect +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + GPT2Tokenizer, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.outputs import BaseOutput +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .modeling_text_decoder import UniDiffuserTextDecoder +from .modeling_uvit import UniDiffuserModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# New BaseOutput child class for joint image-text output +@dataclass +class ImageTextPipelineOutput(BaseOutput): + """ + Output class for joint image-text pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + text (`List[str]` or `List[List[str]]`) + List of generated text strings of length `batch_size` or a list of list of strings whose outer list has + length `batch_size`. + """ + + images: Optional[Union[List[PIL.Image.Image], np.ndarray]] + text: Optional[Union[List[str], List[List[str]]]] + + +class UniDiffuserPipeline(DiffusionPipeline): + r""" + Pipeline for a bimodal image-text model which supports unconditional text and image generation, text-conditioned + image generation, image-conditioned text generation, and joint image-text generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. This + is part of the UniDiffuser image representation along with the CLIP vision encoding. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + image_encoder ([`CLIPVisionModel`]): + A [`~transformers.CLIPVisionModel`] to encode images as part of its image representation along with the VAE + latent representation. + image_processor ([`CLIPImageProcessor`]): + [`~transformers.CLIPImageProcessor`] to preprocess an image before CLIP encoding it with `image_encoder`. + clip_tokenizer ([`CLIPTokenizer`]): + A [`~transformers.CLIPTokenizer`] to tokenize the prompt before encoding it with `text_encoder`. + text_decoder ([`UniDiffuserTextDecoder`]): + Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser + embedding. + text_tokenizer ([`GPT2Tokenizer`]): + A [`~transformers.GPT2Tokenizer`] to decode text for text generation; used along with the `text_decoder`. + unet ([`UniDiffuserModel`]): + A [U-ViT](https://github.com/baofff/U-ViT) model with UNNet-style skip connections between transformer + layers to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The + original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler. + """ + + # TODO: support for moving submodules for components with enable_model_cpu_offload + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae->text_decoder" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModelWithProjection, + clip_image_processor: CLIPImageProcessor, + clip_tokenizer: CLIPTokenizer, + text_decoder: UniDiffuserTextDecoder, + text_tokenizer: GPT2Tokenizer, + unet: UniDiffuserModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + if text_encoder.config.hidden_size != text_decoder.prefix_inner_dim: + raise ValueError( + f"The text encoder hidden size and text decoder prefix inner dim must be the same, but" + f" `text_encoder.config.hidden_size`: {text_encoder.config.hidden_size} and `text_decoder.prefix_inner_dim`: {text_decoder.prefix_inner_dim}" + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + image_encoder=image_encoder, + clip_image_processor=clip_image_processor, + clip_tokenizer=clip_tokenizer, + text_decoder=text_decoder, + text_tokenizer=text_tokenizer, + unet=unet, + scheduler=scheduler, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.num_channels_latents = vae.config.latent_channels + self.text_encoder_seq_len = text_encoder.config.max_position_embeddings + self.text_encoder_hidden_size = text_encoder.config.hidden_size + self.image_encoder_projection_dim = image_encoder.config.projection_dim + self.unet_resolution = unet.config.sample_size + + self.text_intermediate_dim = self.text_encoder_hidden_size + if self.text_decoder.prefix_hidden_dim is not None: + self.text_intermediate_dim = self.text_decoder.prefix_hidden_dim + + self.mode = None + + # TODO: handle safety checking? + self.safety_checker = None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents): + r""" + Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set + mode will be used. + """ + prompt_available = (prompt is not None) or (prompt_embeds is not None) + image_available = image is not None + input_available = prompt_available or image_available + + prompt_latents_available = prompt_latents is not None + vae_latents_available = vae_latents is not None + clip_latents_available = clip_latents is not None + full_latents_available = latents is not None + image_latents_available = vae_latents_available and clip_latents_available + all_indv_latents_available = prompt_latents_available and image_latents_available + + if self.mode is not None: + # Preferentially use the mode set by the user + mode = self.mode + elif prompt_available: + mode = "text2img" + elif image_available: + mode = "img2text" + else: + # Neither prompt nor image supplied, infer based on availability of latents + if full_latents_available or all_indv_latents_available: + mode = "joint" + elif prompt_latents_available: + mode = "text" + elif image_latents_available: + mode = "img" + else: + # No inputs or latents available + mode = "joint" + + # Give warnings for ambiguous cases + if self.mode is None and prompt_available and image_available: + logger.warning( + f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually," + f" defaulting to mode '{mode}'." + ) + + if self.mode is None and not input_available: + if vae_latents_available != clip_latents_available: + # Exactly one of vae_latents and clip_latents is supplied + logger.warning( + f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none" + f" are expected to be supplied. Defaulting to mode '{mode}'." + ) + elif not prompt_latents_available and not vae_latents_available and not clip_latents_available: + # No inputs or latents supplied + logger.warning( + f"No inputs or latents have been supplied, and mode has not been manually set," + f" defaulting to mode '{mode}'." + ) + + return mode + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Functions to manually set the mode + def set_text_mode(self): + r"""Manually set the generation mode to unconditional ("marginal") text generation.""" + self.mode = "text" + + def set_image_mode(self): + r"""Manually set the generation mode to unconditional ("marginal") image generation.""" + self.mode = "img" + + def set_text_to_image_mode(self): + r"""Manually set the generation mode to text-conditioned image generation.""" + self.mode = "text2img" + + def set_image_to_text_mode(self): + r"""Manually set the generation mode to image-conditioned text generation.""" + self.mode = "img2text" + + def set_joint_mode(self): + r"""Manually set the generation mode to unconditional joint image-text generation.""" + self.mode = "joint" + + def reset_mode(self): + r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs.""" + self.mode = None + + def _infer_batch_size( + self, + mode, + prompt, + prompt_embeds, + image, + num_images_per_prompt, + num_prompts_per_image, + latents, + prompt_latents, + vae_latents, + clip_latents, + ): + r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`.""" + if num_images_per_prompt is None: + num_images_per_prompt = 1 + if num_prompts_per_image is None: + num_prompts_per_image = 1 + + assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer" + assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive integer" + + if mode in ["text2img"]: + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + # Either prompt or prompt_embeds must be present for text2img. + batch_size = prompt_embeds.shape[0] + multiplier = num_images_per_prompt + elif mode in ["img2text"]: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + else: + # Image must be available and type either PIL.Image.Image or torch.Tensor. + # Not currently supporting something like image_embeds. + batch_size = image.shape[0] + multiplier = num_prompts_per_image + elif mode in ["img"]: + if vae_latents is not None: + batch_size = vae_latents.shape[0] + elif clip_latents is not None: + batch_size = clip_latents.shape[0] + else: + batch_size = 1 + multiplier = num_images_per_prompt + elif mode in ["text"]: + if prompt_latents is not None: + batch_size = prompt_latents.shape[0] + else: + batch_size = 1 + multiplier = num_prompts_per_image + elif mode in ["joint"]: + if latents is not None: + batch_size = latents.shape[0] + elif prompt_latents is not None: + batch_size = prompt_latents.shape[0] + elif vae_latents is not None: + batch_size = vae_latents.shape[0] + elif clip_latents is not None: + batch_size = clip_latents.shape[0] + else: + batch_size = 1 + + if num_images_per_prompt == num_prompts_per_image: + multiplier = num_images_per_prompt + else: + multiplier = min(num_images_per_prompt, num_prompts_per_image) + logger.warning( + f"You are using mode `{mode}` and `num_images_per_prompt`: {num_images_per_prompt} and" + f" num_prompts_per_image: {num_prompts_per_image} are not equal. Using batch size equal to" + f" `min(num_images_per_prompt, num_prompts_per_image) = {batch_size}." + ) + return batch_size, multiplier + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with self.tokenizer->self.clip_tokenizer + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.clip_tokenizer) + + text_inputs = self.clip_tokenizer( + prompt, + padding="max_length", + max_length=self.clip_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.clip_tokenizer.batch_decode( + untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.clip_tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.clip_tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents + # Add num_prompts_per_image argument, sample from autoencoder moment distribution + def encode_image_vae_latents( + self, + image, + batch_size, + num_prompts_per_image, + dtype, + device, + do_classifier_free_guidance, + generator=None, + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_prompts_per_image + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + * self.vae.config.scaling_factor + for i in range(batch_size) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + # Scale image_latents by the VAE's scaling factor + image_latents = image_latents * self.vae.config.scaling_factor + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + def encode_image_clip_latents( + self, + image, + batch_size, + num_prompts_per_image, + dtype, + device, + generator=None, + ): + # Map image to CLIP embedding. + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + preprocessed_image = self.clip_image_processor.preprocess( + image, + return_tensors="pt", + ) + preprocessed_image = preprocessed_image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_prompts_per_image + if isinstance(generator, list): + image_latents = [ + self.image_encoder(**preprocessed_image[i : i + 1]).image_embeds for i in range(batch_size) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.image_encoder(**preprocessed_image).image_embeds + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + return image_latents + + def prepare_text_latents( + self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None + ): + # Prepare latents for the CLIP embedded prompt. + shape = (batch_size * num_images_per_prompt, seq_len, hidden_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shace (B, L, D) + latents = latents.repeat(num_images_per_prompt, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + # Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument. + def prepare_image_vae_latents( + self, + batch_size, + num_prompts_per_image, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size * num_prompts_per_image, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shape (B, C, H, W) + latents = latents.repeat(num_prompts_per_image, 1, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_clip_latents( + self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None + ): + # Prepare latents for the CLIP embedded image. + shape = (batch_size * num_prompts_per_image, 1, clip_img_dim) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shape (B, L, D) + latents = latents.repeat(num_prompts_per_image, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_text_latents(self, text_latents, device): + output_token_list, seq_lengths = self.text_decoder.generate_captions( + text_latents, self.text_tokenizer.eos_token_id, device=device + ) + output_list = output_token_list.cpu().numpy() + generated_text = [ + self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) + for output, length in zip(output_list, seq_lengths) + ] + return generated_text + + def _split(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W) + and (B, 1, clip_img_dim) + """ + batch_size = x.shape[0] + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + + img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_projection_dim], dim=1) + + img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)) + img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim)) + return img_vae, img_clip + + def _combine(self, img_vae, img_clip): + r""" + Combines a latent iamge img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1, + clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim). + """ + img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) + img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) + return torch.concat([img_vae, img_clip], dim=-1) + + def _split_joint(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim + text_seq_len * text_dim] into (img_vae, + img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is + of shape (B, text_seq_len, text_dim). + """ + batch_size = x.shape[0] + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + text_dim = self.text_encoder_seq_len * self.text_intermediate_dim + + img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1) + + img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)) + img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim)) + text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim)) + return img_vae, img_clip, text + + def _combine_joint(self, img_vae, img_clip, text): + r""" + Combines a latent image img_vae of shape (B, C, H, W), a CLIP-embedded image img_clip of shape (B, L_img, + clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B, + C * H * W + L_img * clip_img_dim + L_text * text_dim). + """ + img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) + img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) + text = torch.reshape(text, (text.shape[0], -1)) + return torch.concat([img_vae, img_clip, text], dim=-1) + + def _get_noise_pred( + self, + mode, + latents, + t, + prompt_embeds, + img_vae, + img_clip, + max_timestep, + data_type, + guidance_scale, + generator, + device, + height, + width, + ): + r""" + Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary. + """ + if mode == "joint": + # Joint text-image generation + img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, text_latents, timestep_img=t, timestep_text=t, data_type=data_type + ) + + x_out = self._combine_joint(img_vae_out, img_clip_out, text_out) + + if guidance_scale <= 1.0: + return x_out + + # Classifier-free guidance + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + + _, _, text_out_uncond = self.unet( + img_vae_T, img_clip_T, text_latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + img_vae_out_uncond, img_clip_out_uncond, _ = self.unet( + img_vae_latents, + img_clip_latents, + text_T, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond) + + return guidance_scale * x_out + (1.0 - guidance_scale) * x_out_uncond + elif mode == "text2img": + # Text-conditioned image generation + img_vae_latents, img_clip_latents = self._split(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, prompt_embeds, timestep_img=t, timestep_text=0, data_type=data_type + ) + + img_out = self._combine(img_vae_out, img_clip_out) + + if guidance_scale <= 1.0: + return img_out + + # Classifier-free guidance + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_latents, + img_clip_latents, + text_T, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond) + + return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond + elif mode == "img2text": + # Image-conditioned text generation + img_vae_out, img_clip_out, text_out = self.unet( + img_vae, img_clip, latents, timestep_img=0, timestep_text=t, data_type=data_type + ) + + if guidance_scale <= 1.0: + return text_out + + # Classifier-free guidance + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_T, img_clip_T, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond + elif mode == "text": + # Unconditional ("marginal") text generation (no CFG) + img_vae_out, img_clip_out, text_out = self.unet( + img_vae, img_clip, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + return text_out + elif mode == "img": + # Unconditional ("marginal") image generation (no CFG) + img_vae_latents, img_clip_latents = self._split(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, + img_clip_latents, + prompt_embeds, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + img_out = self._combine(img_vae_out, img_clip_out) + return img_out + + def check_latents_shape(self, latents_name, latents, expected_shape): + latents_shape = latents.shape + expected_num_dims = len(expected_shape) + 1 # expected dimensions plus the batch dimension + expected_shape_str = ", ".join(str(dim) for dim in expected_shape) + if len(latents_shape) != expected_num_dims: + raise ValueError( + f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" + f" {latents_shape} has {len(latents_shape)} dimensions." + ) + for i in range(1, expected_num_dims): + if latents_shape[i] != expected_shape[i - 1]: + raise ValueError( + f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" + f" {latents_shape} has {latents_shape[i]} != {expected_shape[i - 1]} at dimension {i}." + ) + + def check_inputs( + self, + mode, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + latents=None, + prompt_latents=None, + vae_latents=None, + clip_latents=None, + ): + # Check inputs before running the generative process. + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if mode == "text2img": + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if mode == "img2text": + if image is None: + raise ValueError("`img2text` mode requires an image to be provided.") + + # Check provided latents + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + full_latents_available = latents is not None + prompt_latents_available = prompt_latents is not None + vae_latents_available = vae_latents is not None + clip_latents_available = clip_latents is not None + + if full_latents_available: + individual_latents_available = ( + prompt_latents is not None or vae_latents is not None or clip_latents is not None + ) + if individual_latents_available: + logger.warning( + "You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and" + " `clip_latents`. The value of `latents` will override the value of any individually supplied latents." + ) + # Check shape of full latents + img_vae_dim = self.num_channels_latents * latent_height * latent_width + text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size + latents_dim = img_vae_dim + self.image_encoder_projection_dim + text_dim + latents_expected_shape = (latents_dim,) + self.check_latents_shape("latents", latents, latents_expected_shape) + + # Check individual latent shapes, if present + if prompt_latents_available: + prompt_latents_expected_shape = (self.text_encoder_seq_len, self.text_encoder_hidden_size) + self.check_latents_shape("prompt_latents", prompt_latents, prompt_latents_expected_shape) + + if vae_latents_available: + vae_latents_expected_shape = (self.num_channels_latents, latent_height, latent_width) + self.check_latents_shape("vae_latents", vae_latents, vae_latents_expected_shape) + + if clip_latents_available: + clip_latents_expected_shape = (1, self.image_encoder_projection_dim) + self.check_latents_shape("clip_latents", clip_latents, clip_latents_expected_shape) + + if mode in ["text2img", "img"] and vae_latents_available and clip_latents_available: + if vae_latents.shape[0] != clip_latents.shape[0]: + raise ValueError( + f"Both `vae_latents` and `clip_latents` are supplied, but their batch dimensions are not equal:" + f" {vae_latents.shape[0]} != {clip_latents.shape[0]}." + ) + + if mode == "joint" and prompt_latents_available and vae_latents_available and clip_latents_available: + if prompt_latents.shape[0] != vae_latents.shape[0] or prompt_latents.shape[0] != clip_latents.shape[0]: + raise ValueError( + f"All of `prompt_latents`, `vae_latents`, and `clip_latents` are supplied, but their batch" + f" dimensions are not equal: {prompt_latents.shape[0]} != {vae_latents.shape[0]}" + f" != {clip_latents.shape[0]}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + data_type: Optional[int] = 1, + num_inference_steps: int = 50, + guidance_scale: float = 8.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + num_prompts_per_image: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_latents: Optional[torch.Tensor] = None, + vae_latents: Optional[torch.Tensor] = None, + clip_latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + Required for text-conditioned image generation (`text2img`) mode. + image (`torch.Tensor` or `PIL.Image.Image`, *optional*): + `Image` or tensor representing an image batch. Required for image-conditioned text generation + (`img2text`) mode. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + data_type (`int`, *optional*, defaults to 1): + The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type + embedding; this is added for compatibility with the + [UniDiffuser-v1](https://huggingface.co/thu-ml/unidiffuser-v1) checkpoint. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 8.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). Used in + text-conditioned image generation (`text2img`) mode. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and + `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are + supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples are generated. + num_prompts_per_image (`int`, *optional*, defaults to 1): + The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) and + `text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are + supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples are generated. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for joint + image-text generation. Can be used to tweak the same generation with different prompts. If not + provided, a latents tensor is generated by sampling using the supplied random `generator`. This assumes + a full set of VAE, CLIP, and text latents, if supplied, overrides the value of `prompt_latents`, + `vae_latents`, and `clip_latents`. + prompt_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for text + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + vae_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + clip_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. Used in text-conditioned + image generation (`text2img`) mode. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are be generated from the `negative_prompt` input argument. Used + in text-conditioned image generation (`text2img`) mode. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImageTextPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Returns: + [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.unidiffuser.ImageTextPipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images and the second element + is a list of generated texts. + """ + + # 0. Default height and width to unet + height = height or self.unet_resolution * self.vae_scale_factor + width = width or self.unet_resolution * self.vae_scale_factor + + # 1. Check inputs + # Recalculate mode for each call to the pipeline. + mode = self._infer_mode(prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents) + self.check_inputs( + mode, + prompt, + image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + latents, + prompt_latents, + vae_latents, + clip_latents, + ) + + # 2. Define call parameters + batch_size, multiplier = self._infer_batch_size( + mode, + prompt, + prompt_embeds, + image, + num_images_per_prompt, + num_prompts_per_image, + latents, + prompt_latents, + vae_latents, + clip_latents, + ) + device = self._execution_device + reduce_text_emb_dim = self.text_intermediate_dim < self.text_encoder_hidden_size or self.mode != "text2img" + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + # Note that this differs from the formulation in the unidiffusers paper! + do_classifier_free_guidance = guidance_scale > 1.0 + + # check if scheduler is in sigmas space + # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + + # 3. Encode input prompt, if available; otherwise prepare text latents + if latents is not None: + # Overwrite individual latents + vae_latents, clip_latents, prompt_latents = self._split_joint(latents, height, width) + + if mode in ["text2img"]: + # 3.1. Encode input prompt, if available + assert prompt is not None or prompt_embeds is not None + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=multiplier, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # if do_classifier_free_guidance: + # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + else: + # 3.2. Prepare text latent variables, if input not available + prompt_embeds = self.prepare_text_latents( + batch_size=batch_size, + num_images_per_prompt=multiplier, + seq_len=self.text_encoder_seq_len, + hidden_size=self.text_encoder_hidden_size, + dtype=self.text_encoder.dtype, # Should work with both full precision and mixed precision + device=device, + generator=generator, + latents=prompt_latents, + ) + + if reduce_text_emb_dim: + prompt_embeds = self.text_decoder.encode(prompt_embeds) + + # 4. Encode image, if available; otherwise prepare image latents + if mode in ["img2text"]: + # 4.1. Encode images, if available + assert image is not None, "`img2text` requires a conditioning image" + # Encode image using VAE + image_vae = self.image_processor.preprocess(image) + height, width = image_vae.shape[-2:] + image_vae_latents = self.encode_image_vae_latents( + image=image_vae, + batch_size=batch_size, + num_prompts_per_image=multiplier, + dtype=prompt_embeds.dtype, + device=device, + do_classifier_free_guidance=False, # Copied from InstructPix2Pix, don't use their version of CFG + generator=generator, + ) + + # Encode image using CLIP + image_clip_latents = self.encode_image_clip_latents( + image=image, + batch_size=batch_size, + num_prompts_per_image=multiplier, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + ) + # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size) + image_clip_latents = image_clip_latents.unsqueeze(1) + else: + # 4.2. Prepare image latent variables, if input not available + # Prepare image VAE latents in latent space + image_vae_latents = self.prepare_image_vae_latents( + batch_size=batch_size, + num_prompts_per_image=multiplier, + num_channels_latents=self.num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=vae_latents, + ) + + # Prepare image CLIP latents + image_clip_latents = self.prepare_image_clip_latents( + batch_size=batch_size, + num_prompts_per_image=multiplier, + clip_img_dim=self.image_encoder_projection_dim, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=clip_latents, + ) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + # max_timestep = timesteps[0] + max_timestep = self.scheduler.config.num_train_timesteps + + # 6. Prepare latent variables + if mode == "joint": + latents = self._combine_joint(image_vae_latents, image_clip_latents, prompt_embeds) + elif mode in ["text2img", "img"]: + latents = self._combine(image_vae_latents, image_clip_latents) + elif mode in ["img2text", "text"]: + latents = prompt_embeds + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + logger.debug(f"Scheduler extra step kwargs: {extra_step_kwargs}") + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # predict the noise residual + # Also applies classifier-free guidance as described in the UniDiffuser paper + noise_pred = self._get_noise_pred( + mode, + latents, + t, + prompt_embeds, + image_vae_latents, + image_clip_latents, + max_timestep, + data_type, + guidance_scale, + generator, + device, + height, + width, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 9. Post-processing + image = None + text = None + if mode == "joint": + image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width) + + if not output_type == "latent": + # Map latent VAE image back to pixel space + image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = image_vae_latents + + text = self.decode_text_latents(text_latents, device) + elif mode in ["text2img", "img"]: + image_vae_latents, image_clip_latents = self._split(latents, height, width) + + if not output_type == "latent": + # Map latent VAE image back to pixel space + image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = image_vae_latents + elif mode in ["img2text", "text"]: + text_latents = latents + text = self.decode_text_latents(text_latents, device) + + self.maybe_free_model_hooks() + + # 10. Postprocess the image, if necessary + if image is not None: + do_denormalize = [True] * image.shape[0] + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, text) + + return ImageTextPipelineOutput(images=image, text=text) diff --git a/diffusers/src/diffusers/pipelines/wuerstchen/__init__.py b/diffusers/src/diffusers/pipelines/wuerstchen/__init__.py new file mode 100755 index 0000000..ddb852d --- /dev/null +++ b/diffusers/src/diffusers/pipelines/wuerstchen/__init__.py @@ -0,0 +1,56 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_paella_vq_model"] = ["PaellaVQModel"] + _import_structure["modeling_wuerstchen_diffnext"] = ["WuerstchenDiffNeXt"] + _import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"] + _import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"] + _import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"] + _import_structure["pipeline_wuerstchen_prior"] = ["DEFAULT_STAGE_C_TIMESTEPS", "WuerstchenPriorPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modeling_paella_vq_model import PaellaVQModel + from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt + from .modeling_wuerstchen_prior import WuerstchenPrior + from .pipeline_wuerstchen import WuerstchenDecoderPipeline + from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline + from .pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPriorPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py b/diffusers/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py new file mode 100755 index 0000000..b2cf8cb --- /dev/null +++ b/diffusers/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py @@ -0,0 +1,172 @@ +# Copyright (c) 2022 Dominic Rampas MIT License +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.autoencoders.vae import DecoderOutput, VectorQuantizer +from ...models.modeling_utils import ModelMixin +from ...models.vq_model import VQEncoderOutput +from ...utils.accelerate_utils import apply_forward_hook + + +class MixingResidualBlock(nn.Module): + """ + Residual block with mixing used by Paella's VQ-VAE. + """ + + def __init__(self, inp_channels, embed_dim): + super().__init__() + # depthwise + self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels) + ) + + # channelwise + self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels) + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + def forward(self, x): + mods = self.gammas + x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1] + x = x + self.depthwise(x_temp) * mods[2] + x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + return x + + +class PaellaVQModel(ModelMixin, ConfigMixin): + r"""VQ-VAE model from Paella model. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image. + levels (int, *optional*, defaults to 2): Number of levels in the model. + bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model. + embed_dim (int, *optional*, defaults to 384): Number of hidden channels in the model. + latent_channels (int, *optional*, defaults to 4): Number of latent channels in the VQ-VAE model. + num_vq_embeddings (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE. + scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_down_scale_factor: int = 2, + levels: int = 2, + bottleneck_blocks: int = 12, + embed_dim: int = 384, + latent_channels: int = 4, + num_vq_embeddings: int = 8192, + scale_factor: float = 0.3764, + ): + super().__init__() + + c_levels = [embed_dim // (2**i) for i in reversed(range(levels))] + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(up_down_scale_factor), + nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1), + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = MixingResidualBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append( + nn.Sequential( + nn.Conv2d(c_levels[-1], latent_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(latent_channels), # then normalize them to have mean 0 and std 1 + ) + ) + self.down_blocks = nn.Sequential(*down_blocks) + + # Vector Quantizer + self.vquantizer = VectorQuantizer(num_vq_embeddings, vq_embed_dim=latent_channels, legacy=False, beta=0.25) + + # Decoder blocks + up_blocks = [nn.Sequential(nn.Conv2d(latent_channels, c_levels[-1], kernel_size=1))] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d( + c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1 + ) + ) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1), + nn.PixelShuffle(up_down_scale_factor), + ) + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput: + h = self.in_block(x) + h = self.down_blocks(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + @apply_forward_hook + def decode( + self, h: torch.Tensor, force_not_quantize: bool = True, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + if not force_not_quantize: + quant, _, _ = self.vquantizer(h) + else: + quant = h + + x = self.up_blocks(quant) + dec = self.out_block(x) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward(self, sample: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + h = self.encode(x).latents + dec = self.decode(h).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/diffusers/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/diffusers/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py new file mode 100755 index 0000000..73e71b3 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn + +from ...models.attention_processor import Attention + + +class WuerstchenLayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = super().forward(x) + return x.permute(0, 3, 1, 2) + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep): + super().__init__() + + self.mapper = nn.Linear(c_timestep, c * 2) + + def forward(self, x, t): + a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1) + return x * (1 + a) + b + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + + self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c) + ) + + def forward(self, x, x_skip=None): + x_res = x + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1) + x = self.channelwise(x).permute(0, 3, 1, 2) + return x + x_res + + +# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 +class GlobalResponseNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * stand_div_norm) + self.beta + x + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + + self.self_attn = self_attn + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) + self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + norm_x = self.norm(x) + if self.self_attn: + batch_size, channel, _, _ = x.shape + kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1) + x = x + self.attention(norm_x, encoder_hidden_states=kv) + return x diff --git a/diffusers/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py b/diffusers/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py new file mode 100755 index 0000000..6c06cc0 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py @@ -0,0 +1,254 @@ +# Copyright (c) 2023 Dominic Rampas MIT License +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from .modeling_wuerstchen_common import AttnBlock, GlobalResponseNorm, TimestepBlock, WuerstchenLayerNorm + + +class WuerstchenDiffNeXt(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + c_in=4, + c_out=4, + c_r=64, + patch_size=2, + c_cond=1024, + c_hidden=[320, 640, 1280, 1280], + nhead=[-1, 10, 20, 20], + blocks=[4, 4, 14, 4], + level_config=["CT", "CTA", "CTA", "CTA"], + inject_effnet=[False, True, True, True], + effnet_embd=16, + clip_embd=1024, + kernel_size=3, + dropout=0.1, + ): + super().__init__() + self.c_r = c_r + self.c_cond = c_cond + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + + # CONDITIONING + self.clip_mapper = nn.Linear(clip_embd, c_cond) + self.effnet_mappers = nn.ModuleList( + [ + nn.Conv2d(effnet_embd, c_cond, kernel_size=1) if inject else None + for inject in inject_effnet + list(reversed(inject_effnet)) + ] + ) + self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1), + WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6), + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): + if block_type == "C": + return ResBlockStageB(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == "A": + return AttnBlock(c_hidden, c_cond, nhead, self_attn=True, dropout=dropout) + elif block_type == "T": + return TimestepBlock(c_hidden, c_r) + else: + raise ValueError(f"Block type {block_type} not supported") + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + for i in range(len(c_hidden)): + down_block = nn.ModuleList() + if i > 0: + down_block.append( + nn.Sequential( + WuerstchenLayerNorm(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), + ) + ) + for _ in range(blocks[i]): + for block_type in level_config[i]: + c_skip = c_cond if inject_effnet[i] else 0 + down_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i])) + self.down_blocks.append(down_block) + + # -- up blocks + self.up_blocks = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + up_block = nn.ModuleList() + for j in range(blocks[i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + c_skip += c_cond if inject_effnet[i] else 0 + up_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i])) + if i > 0: + up_block.append( + nn.Sequential( + WuerstchenLayerNorm(c_hidden[i], elementwise_affine=False, eps=1e-6), + nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2), + ) + ) + self.up_blocks.append(up_block) + + # OUTPUT + self.clf = nn.Sequential( + WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], 2 * c_out * (patch_size**2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) + + def _init_weights(self, m): + # General init + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + for mapper in self.effnet_mappers: + if mapper is not None: + nn.init.normal_(mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlockStageB): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks)) + elif isinstance(block, TimestepBlock): + nn.init.constant_(block.mapper.weight, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb.to(dtype=r.dtype) + + def gen_c_embeddings(self, clip): + clip = self.clip_mapper(clip) + clip = self.seq_norm(clip) + return clip + + def _down_encode(self, x, r_embed, effnet, clip=None): + level_outputs = [] + for i, down_block in enumerate(self.down_blocks): + effnet_c = None + for block in down_block: + if isinstance(block, ResBlockStageB): + if effnet_c is None and self.effnet_mappers[i] is not None: + dtype = effnet.dtype + effnet_c = self.effnet_mappers[i]( + nn.functional.interpolate( + effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True + ).to(dtype) + ) + skip = effnet_c if self.effnet_mappers[i] is not None else None + x = block(x, skip) + elif isinstance(block, AttnBlock): + x = block(x, clip) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, effnet, clip=None): + x = level_outputs[0] + for i, up_block in enumerate(self.up_blocks): + effnet_c = None + for j, block in enumerate(up_block): + if isinstance(block, ResBlockStageB): + if effnet_c is None and self.effnet_mappers[len(self.down_blocks) + i] is not None: + dtype = effnet.dtype + effnet_c = self.effnet_mappers[len(self.down_blocks) + i]( + nn.functional.interpolate( + effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True + ).to(dtype) + ) + skip = level_outputs[i] if j == 0 and i > 0 else None + if effnet_c is not None: + if skip is not None: + skip = torch.cat([skip, effnet_c], dim=1) + else: + skip = effnet_c + x = block(x, skip) + elif isinstance(block, AttnBlock): + x = block(x, clip) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + return x + + def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=True): + if x_cat is not None: + x = torch.cat([x, x_cat], dim=1) + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + if clip is not None: + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x_in = x + x = self.embedding(x) + level_outputs = self._down_encode(x, r_embed, effnet, clip) + x = self._up_decode(level_outputs, r_embed, effnet, clip) + a, b = self.clf(x).chunk(2, dim=1) + b = b.sigmoid() * (1 - eps * 2) + eps + if return_noise: + return (x_in - a) / b + else: + return a, b + + +class ResBlockStageB(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + nn.Linear(c * 4, c), + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res diff --git a/diffusers/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/diffusers/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py new file mode 100755 index 0000000..edb0c1e --- /dev/null +++ b/diffusers/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -0,0 +1,198 @@ +# Copyright (c) 2023 Dominic Rampas MIT License +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from ...models.modeling_utils import ModelMixin +from ...utils import is_torch_version +from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm + + +class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + unet_name = "prior" + _supports_gradient_checkpointing = True + + @register_to_config + def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): + super().__init__() + + self.c_r = c_r + self.projection = nn.Conv2d(c_in, c, kernel_size=1) + self.cond_mapper = nn.Sequential( + nn.Linear(c_cond, c), + nn.LeakyReLU(0.2), + nn.Linear(c, c), + ) + + self.blocks = nn.ModuleList() + for _ in range(depth): + self.blocks.append(ResBlock(c, dropout=dropout)) + self.blocks.append(TimestepBlock(c, c_r)) + self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) + self.out = nn.Sequential( + WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6), + nn.Conv2d(c, c_in * 2, kernel_size=1), + ) + + self.gradient_checkpointing = False + self.set_default_attn_processor() + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb.to(dtype=r.dtype) + + def forward(self, x, r, c): + x_in = x + x = self.projection(x) + c_embed = self.cond_mapper(c) + r_embed = self.gen_r_embedding(r) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + for block in self.blocks: + if isinstance(block, AttnBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, c_embed, use_reentrant=False + ) + elif isinstance(block, TimestepBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, r_embed, use_reentrant=False + ) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) + else: + for block in self.blocks: + if isinstance(block, AttnBlock): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed) + elif isinstance(block, TimestepBlock): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x) + else: + for block in self.blocks: + if isinstance(block, AttnBlock): + x = block(x, c_embed) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + a, b = self.out(x).chunk(2, dim=1) + return (x_in - a) / ((1 - b).abs() + 1e-5) diff --git a/diffusers/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/diffusers/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py new file mode 100755 index 0000000..b084214 --- /dev/null +++ b/diffusers/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -0,0 +1,438 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import deprecate, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modeling_paella_vq_model import PaellaVQModel +from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline + + >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( + ... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16 + ... ).to("cuda") + >>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain("warp-ai/wuerstchen", torch_dtype=torch.float16).to( + ... "cuda" + ... ) + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt) + ``` +""" + + +class WuerstchenDecoderPipeline(DiffusionPipeline): + """ + Pipeline for generating images from the Wuerstchen model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The CLIP tokenizer. + text_encoder (`CLIPTextModel`): + The CLIP text encoder. + decoder ([`WuerstchenDiffNeXt`]): + The WuerstchenDiffNeXt unet decoder. + vqgan ([`PaellaVQModel`]): + The VQGAN model. + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + latent_dim_scale (float, `optional`, defaults to 10.67): + Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are + height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and + width=int(24*10.67)=256 in order to match the training conditions. + """ + + model_cpu_offload_seq = "text_encoder->decoder->vqgan" + _callback_tensor_inputs = [ + "latents", + "text_encoder_hidden_states", + "negative_prompt_embeds", + "image_embeddings", + ] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + decoder: WuerstchenDiffNeXt, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + latent_dim_scale: float = 10.67, + ) -> None: + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + ) + self.register_to_config(latent_dim_scale=latent_dim_scale) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device)) + text_encoder_hidden_states = text_encoder_output.last_hidden_state + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_text_encoder_hidden_states = None + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device) + ) + + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + return text_encoder_hidden_states, uncond_text_encoder_hidden_states + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeddings: Union[torch.Tensor, List[torch.Tensor]], + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 12, + timesteps: Optional[List[float]] = None, + guidance_scale: float = 0.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embedding (`torch.Tensor` or `List[torch.Tensor]`): + Image Embeddings either extracted from an image or generated by a Prior Model. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 12): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely + linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image + embeddings. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # 0. Define commonly used variables + device = self._execution_device + dtype = self.decoder.dtype + self._guidance_scale = guidance_scale + + # 1. Check inputs. Raise error if not correct + if not isinstance(prompt, list): + if isinstance(prompt, str): + prompt = [prompt] + else: + raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") + + if self.do_classifier_free_guidance: + if negative_prompt is not None and not isinstance(negative_prompt, list): + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + else: + raise TypeError( + f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}." + ) + + if isinstance(image_embeddings, list): + image_embeddings = torch.cat(image_embeddings, dim=0) + if isinstance(image_embeddings, np.ndarray): + image_embeddings = torch.Tensor(image_embeddings, device=device).to(dtype=dtype) + if not isinstance(image_embeddings, torch.Tensor): + raise TypeError( + f"'image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(image_embeddings)}." + ) + + if not isinstance(num_inference_steps, int): + raise TypeError( + f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument." + ) + + # 2. Encode caption + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + image_embeddings.size(0) * num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + ) + text_encoder_hidden_states = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds + ) + effnet = ( + torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) + if self.do_classifier_free_guidance + else image_embeddings + ) + + # 3. Determine latent shape of latents + latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale) + latent_width = int(image_embeddings.size(3) * self.config.latent_dim_scale) + latent_features_shape = (image_embeddings.size(0) * num_images_per_prompt, 4, latent_height, latent_width) + + # 4. Prepare and set timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler) + + # 6. Run denoising loop + self._num_timesteps = len(timesteps[:-1]) + for i, t in enumerate(self.progress_bar(timesteps[:-1])): + ratio = t.expand(latents.size(0)).to(dtype) + # 7. Denoise latents + predicted_latents = self.decoder( + torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio, + effnet=effnet, + clip=text_encoder_hidden_states, + ) + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) + predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale) + + # 9. Renoise latents to next timestep + latents = self.scheduler.step( + model_output=predicted_latents, + timestep=ratio, + sample=latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + image_embeddings = callback_outputs.pop("image_embeddings", image_embeddings) + text_encoder_hidden_states = callback_outputs.pop( + "text_encoder_hidden_states", text_encoder_hidden_states + ) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}" + ) + + if not output_type == "latent": + # 10. Scale and decode the image latents with vq-vae + latents = self.vqgan.config.scale_factor * latents + images = self.vqgan.decode(latents).sample.clamp(0, 1) + if output_type == "np": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() + elif output_type == "pil": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() + images = self.numpy_to_pil(images) + else: + images = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return images + return ImagePipelineOutput(images) diff --git a/diffusers/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/diffusers/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py new file mode 100755 index 0000000..7819c8c --- /dev/null +++ b/diffusers/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -0,0 +1,306 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import deprecate, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from .modeling_paella_vq_model import PaellaVQModel +from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt +from .modeling_wuerstchen_prior import WuerstchenPrior +from .pipeline_wuerstchen import WuerstchenDecoderPipeline +from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline + + +TEXT2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusions import WuerstchenCombinedPipeline + + >>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-ai/Wuerstchen", torch_dtype=torch.float16).to( + ... "cuda" + ... ) + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> images = pipe(prompt=prompt) + ``` +""" + + +class WuerstchenCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for text-to-image generation using Wuerstchen + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The decoder tokenizer to be used for text inputs. + text_encoder (`CLIPTextModel`): + The decoder text encoder to be used for text inputs. + decoder (`WuerstchenDiffNeXt`): + The decoder model to be used for decoder image generation pipeline. + scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for decoder image generation pipeline. + vqgan (`PaellaVQModel`): + The VQGAN model to be used for decoder image generation pipeline. + prior_tokenizer (`CLIPTokenizer`): + The prior tokenizer to be used for text inputs. + prior_text_encoder (`CLIPTextModel`): + The prior text encoder to be used for text inputs. + prior_prior (`WuerstchenPrior`): + The prior model to be used for prior pipeline. + prior_scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for prior pipeline. + """ + + _load_connected_pipes = True + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + decoder: WuerstchenDiffNeXt, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + prior_tokenizer: CLIPTokenizer, + prior_text_encoder: CLIPTextModel, + prior_prior: WuerstchenPrior, + prior_scheduler: DDPMWuerstchenScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + prior_prior=prior_prior, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + ) + self.prior_pipe = WuerstchenPriorPipeline( + prior=prior_prior, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + ) + self.decoder_pipe = WuerstchenDecoderPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + r""" + Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 + Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a + GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. + Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + prior_num_inference_steps: int = 60, + prior_timesteps: Optional[List[float]] = None, + prior_guidance_scale: float = 4.0, + num_inference_steps: int = 12, + decoder_timesteps: Optional[List[float]] = None, + decoder_guidance_scale: float = 0.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation for the prior and decoder. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `prior_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked + to the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60): + The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. For more specific timestep spacing, you can pass customized + `prior_timesteps` + num_inference_steps (`int`, *optional*, defaults to 12): + The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. For more specific timestep spacing, you can pass customized + `timesteps` + prior_timesteps (`List[float]`, *optional*): + Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced + `prior_num_inference_steps` timesteps are used. Must be in descending order. + decoder_timesteps (`List[float]`, *optional*): + Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced + `num_inference_steps` timesteps are used. Must be in descending order. + decoder_guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + prior_callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: + int, callback_kwargs: Dict)`. + prior_callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the + list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeline class. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + prior_kwargs = {} + if kwargs.get("prior_callback", None) is not None: + prior_kwargs["callback"] = kwargs.pop("prior_callback") + deprecate( + "prior_callback", + "1.0.0", + "Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", + ) + if kwargs.get("prior_callback_steps", None) is not None: + deprecate( + "prior_callback_steps", + "1.0.0", + "Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", + ) + prior_kwargs["callback_steps"] = kwargs.pop("prior_callback_steps") + + prior_outputs = self.prior_pipe( + prompt=prompt if prompt_embeds is None else None, + height=height, + width=width, + num_inference_steps=prior_num_inference_steps, + timesteps=prior_timesteps, + guidance_scale=prior_guidance_scale, + negative_prompt=negative_prompt if negative_prompt_embeds is None else None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + output_type="pt", + return_dict=False, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + **prior_kwargs, + ) + image_embeddings = prior_outputs[0] + + outputs = self.decoder_pipe( + image_embeddings=image_embeddings, + prompt=prompt if prompt is not None else "", + num_inference_steps=num_inference_steps, + timesteps=decoder_timesteps, + guidance_scale=decoder_guidance_scale, + negative_prompt=negative_prompt, + generator=generator, + output_type=output_type, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + **kwargs, + ) + + return outputs diff --git a/diffusers/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/diffusers/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py new file mode 100755 index 0000000..92223ce --- /dev/null +++ b/diffusers/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -0,0 +1,517 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import ceil +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import BaseOutput, deprecate, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .modeling_wuerstchen_prior import WuerstchenPrior + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:] + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import WuerstchenPriorPipeline + + >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( + ... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + ``` +""" + + +@dataclass +class WuerstchenPriorPipelineOutput(BaseOutput): + """ + Output class for WuerstchenPriorPipeline. + + Args: + image_embeddings (`torch.Tensor` or `np.ndarray`) + Prior image embeddings for text prompt + + """ + + image_embeddings: Union[torch.Tensor, np.ndarray] + + +class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + """ + Pipeline for generating image prior for Wuerstchen. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + prior ([`Prior`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + latent_mean ('float', *optional*, defaults to 42.0): + Mean value for latent diffusers. + latent_std ('float', *optional*, defaults to 1.0): + Standard value for latent diffusers. + resolution_multiple ('float', *optional*, defaults to 42.67): + Default resolution for multiple images generated. + """ + + unet_name = "prior" + text_encoder_name = "text_encoder" + model_cpu_offload_seq = "text_encoder->prior" + _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"] + _lora_loadable_modules = ["prior", "text_encoder"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prior: WuerstchenPrior, + scheduler: DDPMWuerstchenScheduler, + latent_mean: float = 42.0, + latent_std: float = 1.0, + resolution_multiple: float = 42.67, + ) -> None: + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + prior=prior, + scheduler=scheduler, + ) + self.register_to_config( + latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt=None, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask.to(device) + ) + prompt_embeds = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + if negative_prompt_embeds is None and do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device) + ) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.last_hidden_state + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + # done duplicates + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + num_inference_steps, + do_classifier_free_guidance, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if not isinstance(num_inference_steps, int): + raise TypeError( + f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 60, + timesteps: List[float] = None, + guidance_scale: float = 8.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pt", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 60): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 8.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely + linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.WuerstchenPriorPipelineOutput`] or `tuple` [`~pipelines.WuerstchenPriorPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated image embeddings. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # 0. Define commonly used variables + device = self._execution_device + self._guidance_scale = guidance_scale + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 1. Check inputs. Raise error if not correct + if prompt is not None and not isinstance(prompt, list): + if isinstance(prompt, str): + prompt = [prompt] + else: + raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") + + if self.do_classifier_free_guidance: + if negative_prompt is not None and not isinstance(negative_prompt, list): + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + else: + raise TypeError( + f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}." + ) + + self.check_inputs( + prompt, + negative_prompt, + num_inference_steps, + self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 2. Encode caption + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_encoder_hidden_states = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds + ) + + # 3. Determine latent shape of image embeddings + dtype = text_encoder_hidden_states.dtype + latent_height = ceil(height / self.config.resolution_multiple) + latent_width = ceil(width / self.config.resolution_multiple) + num_channels = self.prior.config.c_in + effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width) + + # 4. Prepare and set timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler) + + # 6. Run denoising loop + self._num_timesteps = len(timesteps[:-1]) + for i, t in enumerate(self.progress_bar(timesteps[:-1])): + ratio = t.expand(latents.size(0)).to(dtype) + + # 7. Denoise image embeddings + predicted_image_embedding = self.prior( + torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio, + c=text_encoder_hidden_states, + ) + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + predicted_image_embedding = torch.lerp( + predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale + ) + + # 9. Renoise latents to next timestep + latents = self.scheduler.step( + model_output=predicted_image_embedding, + timestep=ratio, + sample=latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + text_encoder_hidden_states = callback_outputs.pop( + "text_encoder_hidden_states", text_encoder_hidden_states + ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 10. Denormalize the latents + latents = latents * self.config.latent_mean - self.config.latent_std + + # Offload all models + self.maybe_free_model_hooks() + + if output_type == "np": + latents = latents.cpu().float().numpy() + + if not return_dict: + return (latents,) + + return WuerstchenPriorPipelineOutput(latents) diff --git a/diffusers/src/diffusers/py.typed b/diffusers/src/diffusers/py.typed new file mode 100755 index 0000000..e69de29 diff --git a/diffusers/src/diffusers/schedulers/README.md b/diffusers/src/diffusers/schedulers/README.md new file mode 100755 index 0000000..31ad277 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/README.md @@ -0,0 +1,3 @@ +# Schedulers + +For more information on the schedulers, please refer to the [docs](https://huggingface.co/docs/diffusers/api/schedulers/overview). \ No newline at end of file diff --git a/diffusers/src/diffusers/schedulers/__init__.py b/diffusers/src/diffusers/schedulers/__init__.py new file mode 100755 index 0000000..bb90885 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/__init__.py @@ -0,0 +1,221 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_scipy_available, + is_torch_available, + is_torchsde_available, +) + + +_dummy_modules = {} +_import_structure = {} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_pt_objects # noqa F403 + + _dummy_modules.update(get_objects_from_module(dummy_pt_objects)) + +else: + _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] + _import_structure["scheduling_amused"] = ["AmusedScheduler"] + _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] + _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] + _import_structure["scheduling_ddim"] = ["DDIMScheduler"] + _import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"] + _import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"] + _import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"] + _import_structure["scheduling_ddpm"] = ["DDPMScheduler"] + _import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"] + _import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"] + _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"] + _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"] + _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"] + _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"] + _import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"] + _import_structure["scheduling_edm_dpmsolver_multistep"] = ["EDMDPMSolverMultistepScheduler"] + _import_structure["scheduling_edm_euler"] = ["EDMEulerScheduler"] + _import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"] + _import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"] + _import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"] + _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] + _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] + _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"] + _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] + _import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"] + _import_structure["scheduling_lcm"] = ["LCMScheduler"] + _import_structure["scheduling_pndm"] = ["PNDMScheduler"] + _import_structure["scheduling_repaint"] = ["RePaintScheduler"] + _import_structure["scheduling_sasolver"] = ["SASolverScheduler"] + _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"] + _import_structure["scheduling_tcd"] = ["TCDScheduler"] + _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"] + _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] + _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"] + _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"] + +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_flax_objects # noqa F403 + + _dummy_modules.update(get_objects_from_module(dummy_flax_objects)) + +else: + _import_structure["scheduling_ddim_flax"] = ["FlaxDDIMScheduler"] + _import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"] + _import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"] + _import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"] + _import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"] + _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"] + _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"] + _import_structure["scheduling_sde_ve_flax"] = ["FlaxScoreSdeVeScheduler"] + _import_structure["scheduling_utils_flax"] = [ + "FlaxKarrasDiffusionSchedulers", + "FlaxSchedulerMixin", + "FlaxSchedulerOutput", + "broadcast_to_shape_from_left", + ] + + +try: + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_scipy_objects # noqa F403 + + _dummy_modules.update(get_objects_from_module(dummy_torch_and_scipy_objects)) + +else: + _import_structure["scheduling_lms_discrete"] = ["LMSDiscreteScheduler"] + +try: + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_torchsde_objects # noqa F403 + + _dummy_modules.update(get_objects_from_module(dummy_torch_and_torchsde_objects)) + +else: + _import_structure["scheduling_cosine_dpmsolver_multistep"] = ["CosineDPMSolverMultistepScheduler"] + _import_structure["scheduling_dpmsolver_sde"] = ["DPMSolverSDEScheduler"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from ..utils import ( + OptionalDependencyNotAvailable, + is_flax_available, + is_scipy_available, + is_torch_available, + is_torchsde_available, + ) + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 + else: + from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler + from .scheduling_amused import AmusedScheduler + from .scheduling_consistency_decoder import ConsistencyDecoderScheduler + from .scheduling_consistency_models import CMStochasticIterativeScheduler + from .scheduling_ddim import DDIMScheduler + from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler + from .scheduling_ddim_inverse import DDIMInverseScheduler + from .scheduling_ddim_parallel import DDIMParallelScheduler + from .scheduling_ddpm import DDPMScheduler + from .scheduling_ddpm_parallel import DDPMParallelScheduler + from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler + from .scheduling_deis_multistep import DEISMultistepScheduler + from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler + from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler + from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler + from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler + from .scheduling_edm_dpmsolver_multistep import EDMDPMSolverMultistepScheduler + from .scheduling_edm_euler import EDMEulerScheduler + from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler + from .scheduling_euler_discrete import EulerDiscreteScheduler + from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler + from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler + from .scheduling_heun_discrete import HeunDiscreteScheduler + from .scheduling_ipndm import IPNDMScheduler + from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler + from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler + from .scheduling_lcm import LCMScheduler + from .scheduling_pndm import PNDMScheduler + from .scheduling_repaint import RePaintScheduler + from .scheduling_sasolver import SASolverScheduler + from .scheduling_sde_ve import ScoreSdeVeScheduler + from .scheduling_tcd import TCDScheduler + from .scheduling_unclip import UnCLIPScheduler + from .scheduling_unipc_multistep import UniPCMultistepScheduler + from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin + from .scheduling_vq_diffusion import VQDiffusionScheduler + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_flax_objects import * # noqa F403 + else: + from .scheduling_ddim_flax import FlaxDDIMScheduler + from .scheduling_ddpm_flax import FlaxDDPMScheduler + from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler + from .scheduling_euler_discrete_flax import FlaxEulerDiscreteScheduler + from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler + from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler + from .scheduling_pndm_flax import FlaxPNDMScheduler + from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler + from .scheduling_utils_flax import ( + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, + ) + + try: + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 + else: + from .scheduling_lms_discrete import LMSDiscreteScheduler + + try: + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 + else: + from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler + from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + for name, value in _dummy_modules.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/schedulers/deprecated/__init__.py b/diffusers/src/diffusers/schedulers/deprecated/__init__.py new file mode 100755 index 0000000..479cf9b --- /dev/null +++ b/diffusers/src/diffusers/schedulers/deprecated/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"] + _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_pt_objects import * # noqa F403 + else: + from .scheduling_karras_ve import KarrasVeScheduler + from .scheduling_sde_vp import ScoreSdeVpScheduler + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py b/diffusers/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py new file mode 100755 index 0000000..f5f9bd2 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py @@ -0,0 +1,243 @@ +# Copyright 2024 NVIDIA and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.torch_utils import randn_tensor +from ..scheduling_utils import SchedulerMixin + + +@dataclass +class KarrasVeOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + derivative (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Derivative of predicted original image sample (x_0). + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + derivative: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +class KarrasVeScheduler(SchedulerMixin, ConfigMixin): + """ + A stochastic scheduler tailored to variance-expanding models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + + + For more details on the parameters, see [Appendix E](https://arxiv.org/abs/2206.00364). The grid search values used + to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in Table 5 of the paper. + + + + Args: + sigma_min (`float`, defaults to 0.02): + The minimum noise magnitude. + sigma_max (`float`, defaults to 100): + The maximum noise magnitude. + s_noise (`float`, defaults to 1.007): + The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000, + 1.011]. + s_churn (`float`, defaults to 80): + The parameter controlling the overall amount of stochasticity. A reasonable range is [0, 100]. + s_min (`float`, defaults to 0.05): + The start value of the sigma range to add noise (enable stochasticity). A reasonable range is [0, 10]. + s_max (`float`, defaults to 50): + The end value of the sigma range to add noise. A reasonable range is [0.2, 80]. + """ + + order = 2 + + @register_to_config + def __init__( + self, + sigma_min: float = 0.02, + sigma_max: float = 100, + s_noise: float = 1.007, + s_churn: float = 80, + s_min: float = 0.05, + s_max: float = 50, + ): + # standard deviation of the initial noise distribution + self.init_noise_sigma = sigma_max + + # setable values + self.num_inference_steps: int = None + self.timesteps: np.IntTensor = None + self.schedule: torch.Tensor = None # sigma(t_i) + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps).to(device) + schedule = [ + ( + self.config.sigma_max**2 + * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) + ) + for i in self.timesteps + ] + self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device) + + def add_noise_to_input( + self, sample: torch.Tensor, sigma: float, generator: Optional[torch.Generator] = None + ) -> Tuple[torch.Tensor, float]: + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to a `gamma_i ≥ 0` to reach a + higher noise level `sigma_hat = sigma_i + gamma_i*sigma_i`. + + Args: + sample (`torch.Tensor`): + The input sample. + sigma (`float`): + generator (`torch.Generator`, *optional*): + A random number generator. + """ + if self.config.s_min <= sigma <= self.config.s_max: + gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) + else: + gamma = 0 + + # sample eps ~ N(0, S_noise^2 * I) + eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device) + sigma_hat = sigma + gamma * sigma + sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) + + return sample_hat, sigma_hat + + def step( + self, + model_output: torch.Tensor, + sigma_hat: float, + sigma_prev: float, + sample_hat: torch.Tensor, + return_dict: bool = True, + ) -> Union[KarrasVeOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + sigma_hat (`float`): + sigma_prev (`float`): + sample_hat (`torch.Tensor`): + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + + """ + + pred_original_sample = sample_hat + sigma_hat * model_output + derivative = (sample_hat - pred_original_sample) / sigma_hat + sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative + + if not return_dict: + return (sample_prev, derivative) + + return KarrasVeOutput( + prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample + ) + + def step_correct( + self, + model_output: torch.Tensor, + sigma_hat: float, + sigma_prev: float, + sample_hat: torch.Tensor, + sample_prev: torch.Tensor, + derivative: torch.Tensor, + return_dict: bool = True, + ) -> Union[KarrasVeOutput, Tuple]: + """ + Corrects the predicted sample based on the `model_output` of the network. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.Tensor`): TODO + sample_prev (`torch.Tensor`): TODO + derivative (`torch.Tensor`): TODO + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`. + + Returns: + prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO + + """ + pred_original_sample = sample_prev + sigma_prev * model_output + derivative_corr = (sample_prev - pred_original_sample) / sigma_prev + sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) + + if not return_dict: + return (sample_prev, derivative) + + return KarrasVeOutput( + prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample + ) + + def add_noise(self, original_samples, noise, timesteps): + raise NotImplementedError() diff --git a/diffusers/src/diffusers/schedulers/deprecated/scheduling_sde_vp.py b/diffusers/src/diffusers/schedulers/deprecated/scheduling_sde_vp.py new file mode 100755 index 0000000..09b02ca --- /dev/null +++ b/diffusers/src/diffusers/schedulers/deprecated/scheduling_sde_vp.py @@ -0,0 +1,109 @@ +# Copyright 2024 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +import math +from typing import Union + +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.torch_utils import randn_tensor +from ..scheduling_utils import SchedulerMixin + + +class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): + """ + `ScoreSdeVpScheduler` is a variance preserving stochastic differential equation (SDE) scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 2000): + The number of diffusion steps to train the model. + beta_min (`int`, defaults to 0.1): + beta_max (`int`, defaults to 20): + sampling_eps (`int`, defaults to 1e-3): + The end value of sampling where timesteps decrease progressively from 1 to epsilon. + """ + + order = 1 + + @register_to_config + def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None): + """ + Sets the continuous timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device) + + def step_pred(self, score, x, t, generator=None): + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + score (): + x (): + t (): + generator (`torch.Generator`, *optional*): + A random number generator. + """ + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + # TODO(Patrick) better comments + non-PyTorch + # postprocess model score + log_mean_coeff = -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min + std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) + std = std.flatten() + while len(std.shape) < len(score.shape): + std = std.unsqueeze(-1) + score = -score / std + + # compute + dt = -1.0 / len(self.timesteps) + + beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) + beta_t = beta_t.flatten() + while len(beta_t.shape) < len(x.shape): + beta_t = beta_t.unsqueeze(-1) + drift = -0.5 * beta_t * x + + diffusion = torch.sqrt(beta_t) + drift = drift - diffusion**2 * score + x_mean = x + drift * dt + + # add noise + noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype) + x = x_mean + diffusion * math.sqrt(-dt) * noise + + return x, x_mean + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_amused.py b/diffusers/src/diffusers/schedulers/scheduling_amused.py new file mode 100755 index 0000000..238b8d8 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_amused.py @@ -0,0 +1,162 @@ +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +def gumbel_noise(t, generator=None): + device = generator.device if generator is not None else t.device + noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device) + return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20)) + + +def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None): + confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator) + sorted_confidence = torch.sort(confidence, dim=-1).values + cut_off = torch.gather(sorted_confidence, 1, mask_len.long()) + masking = confidence < cut_off + return masking + + +@dataclass +class AmusedSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: torch.Tensor = None + + +class AmusedScheduler(SchedulerMixin, ConfigMixin): + order = 1 + + temperatures: torch.Tensor + + @register_to_config + def __init__( + self, + mask_token_id: int, + masking_schedule: str = "cosine", + ): + self.temperatures = None + self.timesteps = None + + def set_timesteps( + self, + num_inference_steps: int, + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + device: Union[str, torch.device] = None, + ): + self.timesteps = torch.arange(num_inference_steps, device=device).flip(0) + + if isinstance(temperature, (tuple, list)): + self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device) + else: + self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device) + + def step( + self, + model_output: torch.Tensor, + timestep: torch.long, + sample: torch.LongTensor, + starting_mask_ratio: int = 1, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[AmusedSchedulerOutput, Tuple]: + two_dim_input = sample.ndim == 3 and model_output.ndim == 4 + + if two_dim_input: + batch_size, codebook_size, height, width = model_output.shape + sample = sample.reshape(batch_size, height * width) + model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1) + + unknown_map = sample == self.config.mask_token_id + + probs = model_output.softmax(dim=-1) + + device = probs.device + probs_ = probs.to(generator.device) if generator is not None else probs # handles when generator is on CPU + if probs_.device.type == "cpu" and probs_.dtype != torch.float32: + probs_ = probs_.float() # multinomial is not implemented for cpu half precision + probs_ = probs_.reshape(-1, probs.size(-1)) + pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device) + pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1]) + pred_original_sample = torch.where(unknown_map, pred_original_sample, sample) + + if timestep == 0: + prev_sample = pred_original_sample + else: + seq_len = sample.shape[1] + step_idx = (self.timesteps == timestep).nonzero() + ratio = (step_idx + 1) / len(self.timesteps) + + if self.config.masking_schedule == "cosine": + mask_ratio = torch.cos(ratio * math.pi / 2) + elif self.config.masking_schedule == "linear": + mask_ratio = 1 - ratio + else: + raise ValueError(f"unknown masking schedule {self.config.masking_schedule}") + + mask_ratio = starting_mask_ratio * mask_ratio + + mask_len = (seq_len * mask_ratio).floor() + # do not mask more than amount previously masked + mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + # mask at least one + mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len) + + selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0] + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + + masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator) + + # Masks tokens with lower confidence. + prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample) + + if two_dim_input: + prev_sample = prev_sample.reshape(batch_size, height, width) + pred_original_sample = pred_original_sample.reshape(batch_size, height, width) + + if not return_dict: + return (prev_sample, pred_original_sample) + + return AmusedSchedulerOutput(prev_sample, pred_original_sample) + + def add_noise(self, sample, timesteps, generator=None): + step_idx = (self.timesteps == timesteps).nonzero() + ratio = (step_idx + 1) / len(self.timesteps) + + if self.config.masking_schedule == "cosine": + mask_ratio = torch.cos(ratio * math.pi / 2) + elif self.config.masking_schedule == "linear": + mask_ratio = 1 - ratio + else: + raise ValueError(f"unknown masking schedule {self.config.masking_schedule}") + + mask_indices = ( + torch.rand( + sample.shape, device=generator.device if generator is not None else sample.device, generator=generator + ).to(sample.device) + < mask_ratio + ) + + masked_sample = sample.clone() + + masked_sample[mask_indices] = self.config.mask_token_id + + return masked_sample diff --git a/diffusers/src/diffusers/schedulers/scheduling_consistency_decoder.py b/diffusers/src/diffusers/schedulers/scheduling_consistency_decoder.py new file mode 100755 index 0000000..d7af018 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_consistency_decoder.py @@ -0,0 +1,180 @@ +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +@dataclass +class ConsistencyDecoderSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.Tensor + + +class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin): + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1024, + sigma_data: float = 0.5, + ): + betas = betas_for_alpha_bar(num_train_timesteps) + + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + + self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + + sigmas = torch.sqrt(1.0 / alphas_cumprod - 1) + + sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod) + + self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2) + self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5 + self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5 + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + ): + if num_inference_steps != 2: + raise ValueError("Currently more than 2 inference steps are not supported.") + + self.timesteps = torch.tensor([1008, 512], dtype=torch.long, device=device) + self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device) + self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device) + self.c_skip = self.c_skip.to(device) + self.c_out = self.c_out.to(device) + self.c_in = self.c_in.to(device) + + @property + def init_noise_sigma(self): + return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]] + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample * self.c_in[timestep] + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor], + sample: torch.Tensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`float`): + The current timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a + [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise + a tuple is returned where the first element is the sample tensor. + """ + x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample + + timestep_idx = torch.where(self.timesteps == timestep)[0] + + if timestep_idx == len(self.timesteps) - 1: + prev_sample = x_0 + else: + noise = randn_tensor(x_0.shape, generator=generator, dtype=x_0.dtype, device=x_0.device) + prev_sample = ( + self.sqrt_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * x_0 + + self.sqrt_one_minus_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * noise + ) + + if not return_dict: + return (prev_sample,) + + return ConsistencyDecoderSchedulerOutput(prev_sample=prev_sample) diff --git a/diffusers/src/diffusers/schedulers/scheduling_consistency_models.py b/diffusers/src/diffusers/schedulers/scheduling_consistency_models.py new file mode 100755 index 0000000..6531716 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_consistency_models.py @@ -0,0 +1,446 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class CMStochasticIterativeSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.Tensor + + +class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): + """ + Multistep and onestep sampling for consistency models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 40): + The number of diffusion steps to train the model. + sigma_min (`float`, defaults to 0.002): + Minimum noise magnitude in the sigma schedule. Defaults to 0.002 from the original implementation. + sigma_max (`float`, defaults to 80.0): + Maximum noise magnitude in the sigma schedule. Defaults to 80.0 from the original implementation. + sigma_data (`float`, defaults to 0.5): + The standard deviation of the data distribution from the EDM + [paper](https://huggingface.co/papers/2206.00364). Defaults to 0.5 from the original implementation. + s_noise (`float`, defaults to 1.0): + The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000, + 1.011]. Defaults to 1.0 from the original implementation. + rho (`float`, defaults to 7.0): + The parameter for calculating the Karras sigma schedule from the EDM + [paper](https://huggingface.co/papers/2206.00364). Defaults to 7.0 from the original implementation. + clip_denoised (`bool`, defaults to `True`): + Whether to clip the denoised outputs to `(-1, 1)`. + timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*): + An explicit timestep schedule that can be optionally specified. The timesteps are expected to be in + increasing order. + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 40, + sigma_min: float = 0.002, + sigma_max: float = 80.0, + sigma_data: float = 0.5, + s_noise: float = 1.0, + rho: float = 7.0, + clip_denoised: bool = True, + ): + # standard deviation of the initial noise distribution + self.init_noise_sigma = sigma_max + + ramp = np.linspace(0, 1, num_train_timesteps) + sigmas = self._convert_to_karras(ramp) + timesteps = self.sigma_to_t(sigmas) + + # setable values + self.num_inference_steps = None + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps) + self.custom_timesteps = False + self.is_scale_input_called = False + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`float` or `torch.Tensor`): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + # Get sigma corresponding to timestep + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + + self.is_scale_input_called = True + return sample + + def sigma_to_t(self, sigmas: Union[float, np.ndarray]): + """ + Gets scaled timesteps from the Karras sigmas for input to the consistency model. + + Args: + sigmas (`float` or `np.ndarray`): + A single Karras sigma or an array of Karras sigmas. + + Returns: + `float` or `np.ndarray`: + A scaled input timestep or scaled input timestep array. + """ + if not isinstance(sigmas, np.ndarray): + sigmas = np.array(sigmas, dtype=np.float64) + + timesteps = 1000 * 0.25 * np.log(sigmas + 1e-44) + + return timesteps + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed, + `num_inference_steps` must be `None`. + """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") + + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.") + + # Follow DDPMScheduler custom timesteps logic + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.custom_timesteps = False + + # Map timesteps to Karras sigmas directly for multistep sampling + # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 + num_train_timesteps = self.config.num_train_timesteps + ramp = timesteps[::-1].copy() + ramp = ramp / (num_train_timesteps - 1) + sigmas = self._convert_to_karras(ramp) + timesteps = self.sigma_to_t(sigmas) + + sigmas = np.concatenate([sigmas, [self.config.sigma_min]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Modified _convert_to_karras implementation that takes in ramp as argument + def _convert_to_karras(self, ramp): + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = self.config.sigma_min + sigma_max: float = self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def get_scalings(self, sigma): + sigma_data = self.config.sigma_data + + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + def get_scalings_for_boundary_condition(self, sigma): + """ + Gets the scalings used in the consistency model parameterization (from Appendix C of the + [paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition. + + + + `epsilon` in the equations for `c_skip` and `c_out` is set to `sigma_min`. + + + + Args: + sigma (`torch.Tensor`): + The current sigma in the Karras sigma schedule. + + Returns: + `tuple`: + A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out` + (which weights the consistency model output) is the second element. + """ + sigma_min = self.config.sigma_min + sigma_data = self.config.sigma_data + + c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2) + c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor], + sample: torch.Tensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`float`): + The current timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a + [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + """ + + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + f" `{self.__class__}.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + sigma_min = self.config.sigma_min + sigma_max = self.config.sigma_max + + if self.step_index is None: + self._init_step_index(timestep) + + # sigma_next corresponds to next_t in original implementation + sigma = self.sigmas[self.step_index] + if self.step_index + 1 < self.config.num_train_timesteps: + sigma_next = self.sigmas[self.step_index + 1] + else: + # Set sigma_next to sigma_min + sigma_next = self.sigmas[-1] + + # Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition(sigma) + + # 1. Denoise model output using boundary conditions + denoised = c_out * model_output + c_skip * sample + if self.config.clip_denoised: + denoised = denoised.clamp(-1, 1) + + # 2. Sample z ~ N(0, s_noise^2 * I) + # Noise is not used for onestep sampling. + if len(self.timesteps) > 1: + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + else: + noise = torch.zeros_like(model_output) + z = noise * self.config.s_noise + + sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max) + + # 3. Return noisy sample + # tau = sigma_hat, eps = sigma_min + prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/diffusers/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py new file mode 100755 index 0000000..ab56650 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -0,0 +1,572 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_dpmsolver_sde import BrownianTreeNoiseSampler +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + Implements a variant of `DPMSolverMultistepScheduler` with cosine schedule, proposed by Nichol and Dhariwal (2021). + This scheduler was used in Stable Audio Open [1]. + + [1] Evans, Parker, et al. "Stable Audio Open" https://arxiv.org/abs/2407.14358 + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + sigma_min (`float`, *optional*, defaults to 0.3): + Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1]. + sigma_max (`float`, *optional*, defaults to 500): + Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1]. + sigma_data (`float`, *optional*, defaults to 1.0): + The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1]. + sigma_schedule (`str`, *optional*, defaults to `exponential`): + Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper + (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was + incorporated in this model: https://huggingface.co/stabilityai/cosxl. + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`. + prediction_type (`str`, defaults to `v_prediction`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + sigma_min: float = 0.3, + sigma_max: float = 500, + sigma_data: float = 1.0, + sigma_schedule: str = "exponential", + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "v_prediction", + rho: float = 7.0, + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + ramp = torch.linspace(0, 1, num_train_timesteps) + if sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + self.timesteps = self.precondition_noise(sigmas) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + # setable values + self.num_inference_steps = None + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + return (self.config.sigma_max**2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs + def precondition_inputs(self, sample, sigma): + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + scaled_sample = sample * c_in + return scaled_sample + + def precondition_noise(self, sigma): + if not isinstance(sigma, torch.Tensor): + sigma = torch.tensor([sigma]) + + return sigma.atan() / math.pi * 2 + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs + def precondition_outputs(self, sample, model_output, sigma): + sigma_data = self.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + + if self.config.prediction_type == "epsilon": + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + elif self.config.prediction_type == "v_prediction": + c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + else: + raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.") + + denoised = c_skip * sample + c_out * model_output + + return denoised + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = self.precondition_inputs(sample, sigma) + + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + self.num_inference_steps = num_inference_steps + + ramp = torch.linspace(0, 1, self.num_inference_steps) + if self.config.sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif self.config.sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + sigmas = sigmas.to(dtype=torch.float32, device=device) + self.timesteps = self.precondition_noise(sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = self.config.sigma_min + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)]) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # if a noise sampler is used, reinitialise it + self.noise_sampler = None + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas + def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas + def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Implementation closely follows k-diffusion. + + https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + """ + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1 + sigma_t = sigma + + return alpha_t, sigma_t + + def convert_model_output( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + sigma = self.sigmas[self.step_index] + x0_pred = self.precondition_outputs(sample, model_output, sigma) + + return x0_pred + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + + # sde-dpmsolver++ + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.noise_sampler is None: + seed = None + if generator is not None: + seed = ( + [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed() + ) + self.noise_sampler = BrownianTreeNoiseSampler( + model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed + ) + noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to( + model_output.device + ) + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_ddim.py b/diffusers/src/diffusers/schedulers/scheduling_ddim.py new file mode 100755 index 0000000..14356ea --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_ddim.py @@ -0,0 +1,518 @@ +# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DDIMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/diffusers/src/diffusers/schedulers/scheduling_ddim_cogvideox.py new file mode 100755 index 0000000..ec5c5f3 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -0,0 +1,449 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + + return alphas_bar + + +class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.0120, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + snr_shift_scale: float = 3.0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # Modify: SNR shift following SD3 + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # To make style tests pass, commented out `pred_epsilon` as it is an unused variable + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + # pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 + b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t + + prev_sample = a_t * sample + b_t * pred_original_sample + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_ddim_flax.py b/diffusers/src/diffusers/schedulers/scheduling_ddim_flax.py new file mode 100755 index 0000000..23c71a6 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -0,0 +1,314 @@ +# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + add_noise_common, + get_velocity_common, +) + + +@flax.struct.dataclass +class DDIMSchedulerState: + common: CommonSchedulerState + final_alpha_cumprod: jnp.ndarray + + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + + @classmethod + def create( + cls, + common: CommonSchedulerState, + final_alpha_cumprod: jnp.ndarray, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): + return cls( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + +@dataclass +class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput): + state: DDIMSchedulerState + + +class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + clip_sample (`bool`, default `True`): + option to clip predicted sample between for numerical stability. The clip range is determined by + `clip_sample_range`. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, default `True`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. + """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + clip_sample: bool = True, + clip_sample_range: float = 1.0, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + final_alpha_cumprod = ( + jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] + ) + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + + return DDIMSchedulerState.create( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + def scale_model_input( + self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + """ + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + sample (`jnp.ndarray`): input sample + timestep (`int`, optional): current timestep + + Returns: + `jnp.ndarray`: scaled input sample + """ + return sample + + def set_timesteps( + self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> DDIMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DDIMSchedulerState`): + the `FlaxDDIMScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + step_ratio = self.config.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # rounding to avoid issues when num_inference_step is power of 3 + timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + ) + + def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep): + alpha_prod_t = state.common.alphas_cumprod[timestep] + alpha_prod_t_prev = jnp.where( + prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def step( + self, + state: DDIMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + eta: float = 0.0, + return_dict: bool = True, + ) -> Union[FlaxDDIMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class + + Returns: + [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps + + alphas_cumprod = state.common.alphas_cumprod + final_alpha_cumprod = state.final_alpha_cumprod + + # 2. compute alphas, betas + alpha_prod_t = alphas_cumprod[timestep] + alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.clip_sample: + pred_original_sample = pred_original_sample.clip( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(state, timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if not return_dict: + return (prev_sample, state) + + return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state) + + def add_noise( + self, + state: DDIMSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return add_noise_common(state.common, original_samples, noise, timesteps) + + def get_velocity( + self, + state: DDIMSchedulerState, + sample: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return get_velocity_common(state.common, sample, noise, timesteps) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_ddim_inverse.py b/diffusers/src/diffusers/schedulers/scheduling_ddim_inverse.py new file mode 100755 index 0000000..6c2352f --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -0,0 +1,374 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput, deprecate + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMInverseScheduler` is the reverse scheduler of [`DDIMScheduler`]. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to 0, otherwise + it uses the alpha value at step `num_train_timesteps - 1`. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + order = 1 + ignore_for_config = ["kwargs"] + _deprecated_kwargs = ["set_alpha_to_zero"] + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + **kwargs, + ): + if kwargs.get("set_alpha_to_zero", None) is not None: + deprecation_message = ( + "The `set_alpha_to_zero` argument is deprecated. Please use `set_alpha_to_one` instead." + ) + deprecate("set_alpha_to_zero", "1.0.0", deprecation_message, standard_warn=False) + set_alpha_to_one = kwargs["set_alpha_to_zero"] + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in inverted ddim, we are looking into the next alphas_cumprod + # For the initial step, there is no current alphas_cumprod, and the index is out of bounds + # `set_alpha_to_one` decides whether we set this parameter simply to one + # in this case, self.step() just output the predicted noise + # or whether we use the initial alpha used in training the diffusion model. + self.initial_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64)) + + # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or + `tuple`. + + Returns: + [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + + """ + # 1. get previous step value (=t+1) + prev_timestep = timestep + timestep = min( + timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1 + ) + + # 2. compute alphas, betas + # change original implementation to exactly match noise levels for analogous forward process + alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + + # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if not return_dict: + return (prev_sample, pred_original_sample) + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_ddim_parallel.py b/diffusers/src/diffusers/schedulers/scheduling_ddim_parallel.py new file mode 100755 index 0000000..0cf84b6 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -0,0 +1,643 @@ +# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput +class DDIMParallelSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + clip_sample (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, default `True`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + _is_ode_scheduler = True + + @register_to_config + # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.__init__ + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def _get_variance(self, timestep, prev_timestep=None): + if prev_timestep is None: + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def _batch_get_variance(self, t, prev_t): + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)] + alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[DDIMParallelSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + return_dict (`bool`): option for returning tuple rather than DDIMParallelSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def batch_step_no_noise( + self, + model_output: torch.Tensor, + timesteps: List[int], + sample: torch.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + ) -> torch.Tensor: + """ + Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once. + Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise + is pre-sampled by the pipeline. + + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): direct output from learned diffusion model. + timesteps (`List[int]`): + current discrete timesteps in the diffusion chain. This is now a list of integers. + sample (`torch.Tensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + + Returns: + `torch.Tensor`: sample tensor at previous timestep. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + assert eta == 0.0 + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + t = timesteps + prev_t = t - self.config.num_train_timesteps // self.num_inference_steps + + t = t.view(-1, *([1] * (model_output.ndim - 1))) + prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1))) + + # 1. compute alphas, betas + self.alphas_cumprod = self.alphas_cumprod.to(model_output.device) + self.final_alpha_cumprod = self.final_alpha_cumprod.to(model_output.device) + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)] + alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._batch_get_variance(t, prev_t).to(model_output.device).view(*alpha_prod_t_prev.shape) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + return prev_sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_ddpm.py b/diffusers/src/diffusers/schedulers/scheduling_ddpm.py new file mode 100755 index 0000000..81a770e --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_ddpm.py @@ -0,0 +1,560 @@ +# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +class DDPMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DDPMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + An array of betas to pass directly to the constructor without using `beta_start` and `beta_end`. + variance_type (`str`, defaults to `"fixed_small"`): + Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, + `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "sigmoid": + # GeoDiff sigmoid schedule + betas = torch.linspace(-6, 6, num_train_timesteps) + self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.custom_timesteps = False + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + self.variance_type = variance_type + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed, + `num_inference_steps` must be `None`. + + """ + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`custom_timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + self.custom_timesteps = False + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_variance(self, t, predicted_variance=None, variance_type=None): + prev_t = self.previous_timestep(t) + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t + + # we always take the log of variance, so clamp it to ensure it's not 0 + variance = torch.clamp(variance, min=1e-20) + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small": + variance = variance + # for rl-diffuser https://arxiv.org/abs/2205.09991 + elif variance_type == "fixed_small_log": + variance = torch.log(variance) + variance = torch.exp(0.5 * variance) + elif variance_type == "fixed_large": + variance = current_beta_t + elif variance_type == "fixed_large_log": + # Glide max_log + variance = torch.log(current_beta_t) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = torch.log(variance) + max_log = torch.log(current_beta_t) + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[DDPMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + t = timestep + + prev_t = self.previous_timestep(t) + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + current_alpha_t = alpha_prod_t / alpha_prod_t_prev + current_beta_t = 1 - current_alpha_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for the DDPMScheduler." + ) + + # 3. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t + current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + device = model_output.device + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + if self.variance_type == "fixed_small_log": + variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise + elif self.variance_type == "learned_range": + variance = self._get_variance(t, predicted_variance=predicted_variance) + variance = torch.exp(0.5 * variance) * variance_noise + else: + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + def previous_timestep(self, timestep): + if self.custom_timesteps: + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + else: + num_inference_steps = ( + self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + ) + prev_t = timestep - self.config.num_train_timesteps // num_inference_steps + + return prev_t diff --git a/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py b/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py new file mode 100755 index 0000000..d06a171 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -0,0 +1,303 @@ +# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + add_noise_common, + get_velocity_common, +) + + +@flax.struct.dataclass +class DDPMSchedulerState: + common: CommonSchedulerState + + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + + @classmethod + def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray): + return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps) + + +@dataclass +class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput): + state: DDPMSchedulerState + + +class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. + """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + + return DDPMSchedulerState.create( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + def scale_model_input( + self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + """ + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + sample (`jnp.ndarray`): input sample + timestep (`int`, optional): current timestep + + Returns: + `jnp.ndarray`: scaled input sample + """ + return sample + + def set_timesteps( + self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> DDPMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DDIMSchedulerState`): + the `FlaxDDPMScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + step_ratio = self.config.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # rounding to avoid issues when num_inference_step is power of 3 + timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + ) + + def _get_variance(self, state: DDPMSchedulerState, t, predicted_variance=None, variance_type=None): + alpha_prod_t = state.common.alphas_cumprod[t] + alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype)) + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * state.common.betas[t] + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small": + variance = jnp.clip(variance, a_min=1e-20) + # for rl-diffuser https://arxiv.org/abs/2205.09991 + elif variance_type == "fixed_small_log": + variance = jnp.log(jnp.clip(variance, a_min=1e-20)) + elif variance_type == "fixed_large": + variance = state.common.betas[t] + elif variance_type == "fixed_large_log": + # Glide max_log + variance = jnp.log(state.common.betas[t]) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = variance + max_log = state.common.betas[t] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, + state: DDPMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + key: Optional[jax.Array] = None, + return_dict: bool = True, + ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + key (`jax.Array`): a PRNG key. + return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class + + Returns: + [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + t = timestep + + if key is None: + key = jax.random.key(0) + + if ( + len(model_output.shape) > 1 + and model_output.shape[1] == sample.shape[1] * 2 + and self.config.variance_type in ["learned", "learned_range"] + ): + model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = state.common.alphas_cumprod[t] + alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype)) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " + " for the FlaxDDPMScheduler." + ) + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = jnp.clip(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * state.common.betas[t]) / beta_prod_t + current_sample_coeff = state.common.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + def random_variance(): + split_key = jax.random.split(key, num=1)[0] + noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) + return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise + + variance = jnp.where(t > 0, random_variance(), jnp.zeros(model_output.shape, dtype=self.dtype)) + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample, state) + + return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state) + + def add_noise( + self, + state: DDPMSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return add_noise_common(state.common, original_samples, noise, timesteps) + + def get_velocity( + self, + state: DDPMSchedulerState, + sample: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return get_velocity_common(state.common, sample, noise, timesteps) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/diffusers/src/diffusers/schedulers/scheduling_ddpm_parallel.py new file mode 100755 index 0000000..5dfcf3c --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -0,0 +1,651 @@ +# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput +class DDPMParallelSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + An offset added to the inference steps, as required by some model families. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + _is_ode_scheduler = False + + @register_to_config + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.__init__ + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "sigmoid": + # GeoDiff sigmoid schedule + betas = torch.linspace(-6, 6, num_train_timesteps) + self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.custom_timesteps = False + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + self.variance_type = variance_type + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed, + `num_inference_steps` must be `None`. + + """ + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`custom_timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + self.custom_timesteps = False + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance + def _get_variance(self, t, predicted_variance=None, variance_type=None): + prev_t = self.previous_timestep(t) + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t + + # we always take the log of variance, so clamp it to ensure it's not 0 + variance = torch.clamp(variance, min=1e-20) + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small": + variance = variance + # for rl-diffuser https://arxiv.org/abs/2205.09991 + elif variance_type == "fixed_small_log": + variance = torch.log(variance) + variance = torch.exp(0.5 * variance) + elif variance_type == "fixed_large": + variance = current_beta_t + elif variance_type == "fixed_large_log": + # Glide max_log + variance = torch.log(current_beta_t) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = torch.log(variance) + max_log = torch.log(current_beta_t) + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[DDPMParallelSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDPMParallelSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDPMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + t = timestep + + prev_t = self.previous_timestep(t) + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + current_alpha_t = alpha_prod_t / alpha_prod_t_prev + current_beta_t = 1 - current_alpha_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for the DDPMScheduler." + ) + + # 3. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t + current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + device = model_output.device + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + if self.variance_type == "fixed_small_log": + variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise + elif self.variance_type == "learned_range": + variance = self._get_variance(t, predicted_variance=predicted_variance) + variance = torch.exp(0.5 * variance) * variance_noise + else: + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + def batch_step_no_noise( + self, + model_output: torch.Tensor, + timesteps: List[int], + sample: torch.Tensor, + ) -> torch.Tensor: + """ + Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once. + Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise + is pre-sampled by the pipeline. + + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): direct output from learned diffusion model. + timesteps (`List[int]`): + current discrete timesteps in the diffusion chain. This is now a list of integers. + sample (`torch.Tensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.Tensor`: sample tensor at previous timestep. + """ + t = timesteps + num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + prev_t = t - self.config.num_train_timesteps // num_inference_steps + + t = t.view(-1, *([1] * (model_output.ndim - 1))) + prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1))) + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + pass + + # 1. compute alphas, betas + self.alphas_cumprod = self.alphas_cumprod.to(model_output.device) + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)] + alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0) + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + current_alpha_t = alpha_prod_t / alpha_prod_t_prev + current_beta_t = 1 - current_alpha_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for the DDPMParallelScheduler." + ) + + # 3. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t + current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + return pred_prev_sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep + def previous_timestep(self, timestep): + if self.custom_timesteps: + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + else: + num_inference_steps = ( + self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + ) + prev_t = timestep - self.config.num_train_timesteps // num_inference_steps + + return prev_t diff --git a/diffusers/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/diffusers/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py new file mode 100755 index 0000000..71b5669 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -0,0 +1,230 @@ +# Copyright (c) 2022 Pablo Pernías MIT License +# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +@dataclass +class DDPMWuerstchenSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.Tensor + + +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2006.11239 + + Args: + scaler (`float`): .... + s (`float`): .... + """ + + @register_to_config + def __init__( + self, + scaler: float = 1.0, + s: float = 0.008, + ): + self.scaler = scaler + self.s = torch.tensor([s]) + self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + def _alpha_cumprod(self, t, device): + if self.scaler > 1: + t = 1 - (1 - t) ** self.scaler + elif self.scaler < 1: + t = t**self.scaler + alpha_cumprod = torch.cos( + (t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5 + ) ** 2 / self._init_alpha_cumprod.to(device) + return alpha_cumprod.clamp(0.0001, 0.9999) + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.Tensor`: scaled input sample + """ + return sample + + def set_timesteps( + self, + num_inference_steps: int = None, + timesteps: Optional[List[int]] = None, + device: Union[str, torch.device] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`Dict[float, int]`): + the number of diffusion steps used when generating samples with a pre-trained model. If passed, then + `timesteps` must be `None`. + device (`str` or `torch.device`, optional): + the device to which the timesteps are moved to. {2 / 3: 20, 0.0: 10} + """ + if timesteps is None: + timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device) + if not isinstance(timesteps, torch.Tensor): + timesteps = torch.Tensor(timesteps).to(device) + self.timesteps = timesteps + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[DDPMWuerstchenSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than DDPMWuerstchenSchedulerOutput class + + Returns: + [`DDPMWuerstchenSchedulerOutput`] or `tuple`: [`DDPMWuerstchenSchedulerOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + dtype = model_output.dtype + device = model_output.device + t = timestep + + prev_t = self.previous_timestep(t) + + alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]]) + alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) + alpha = alpha_cumprod / alpha_cumprod_prev + + mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt()) + + std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype) + std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise + pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]]) + + if not return_dict: + return (pred.to(dtype),) + + return DDPMWuerstchenSchedulerOutput(prev_sample=pred.to(dtype)) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + device = original_samples.device + dtype = original_samples.dtype + alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view( + timesteps.size(0), *[1 for _ in original_samples.shape[1:]] + ) + noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise + return noisy_samples.to(dtype=dtype) + + def __len__(self): + return self.config.num_train_timesteps + + def previous_timestep(self, timestep): + index = (self.timesteps - timestep[0]).abs().argmin().item() + prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0]) + return prev_t diff --git a/diffusers/src/diffusers/schedulers/scheduling_deis_multistep.py b/diffusers/src/diffusers/schedulers/scheduling_deis_multistep.py new file mode 100755 index 0000000..11073ce --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -0,0 +1,790 @@ +# Copyright 2024 FLAIR Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: check https://arxiv.org/abs/2204.13902 and https://github.com/qsh-zh/deis for more info +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `DEISMultistepScheduler` is a fast high order solver for diffusion ordinary differential equations (ODEs). + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + algorithm_type (`str`, defaults to `deis`): + The algorithm type for the solver. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "deis", + solver_type: str = "logrho", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DEIS + if algorithm_type not in ["deis"]: + if algorithm_type in ["dpmsolver", "dpmsolver++"]: + self.register_to_config(algorithm_type="deis") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["logrho"]: + if solver_type in ["midpoint", "heun", "bh1", "bh2"]: + self.register_to_config(solver_type="logrho") + else: + raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DEIS algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + if self.config.prediction_type == "epsilon": + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DEISMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + if self.config.algorithm_type == "deis": + return (sample - alpha_t * x0_pred) / sigma_t + else: + raise NotImplementedError("only support log-rho multistep deis now") + + def deis_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DEIS (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "deis": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + else: + raise NotImplementedError("only support log-rho multistep deis now") + return x_t + + def multistep_deis_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order multistep DEIS. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 + + if self.config.algorithm_type == "deis": + + def ind_fn(t, b, c): + # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}] + return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c)) + + coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1) + coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0) + + x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1) + return x_t + else: + raise NotImplementedError("only support log-rho multistep deis now") + + def multistep_deis_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DEIS. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + rho_t, rho_s0, rho_s1, rho_s2 = ( + sigma_t / alpha_t, + sigma_s0 / alpha_s0, + sigma_s1 / alpha_s1, + sigma_s2 / alpha_s2, + ) + + if self.config.algorithm_type == "deis": + + def ind_fn(t, b, c, d): + # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}] + numerator = t * ( + np.log(c) * (np.log(d) - np.log(t) + 1) + - np.log(d) * np.log(t) + + np.log(d) + + np.log(t) ** 2 + - 2 * np.log(t) + + 2 + ) + denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d)) + return numerator / denominator + + coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2) + coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s2, rho_s0) + coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1) + + x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2) + + return x_t + else: + raise NotImplementedError("only support log-rho multistep deis now") + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DEIS. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + lower_order_final = ( + (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.deis_first_order_update(model_output, sample=sample) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample) + else: + prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/diffusers/src/diffusers/schedulers/scheduling_dpm_cogvideox.py new file mode 100755 index 0000000..1688972 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_dpm_cogvideox.py @@ -0,0 +1,523 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + + return alphas_bar + + +class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.0120, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + snr_shift_scale: float = 3.0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # Modify: SNR shift following SD3 + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None): + lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log() + lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log() + h = lamb_next - lamb + + if alpha_prod_t_back is not None: + lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log() + h_last = lamb - lamb_previous + r = h_last / h + return h, r, lamb, lamb_next + else: + return h, None, lamb, lamb_next + + def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back): + mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp() + mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5 + + if alpha_prod_t_back is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def step( + self, + model_output: torch.Tensor, + old_pred_original_sample: torch.Tensor, + timestep: int, + timestep_back: int, + sample: torch.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = False, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # To make style tests pass, commented out `pred_epsilon` as it is an unused variable + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + # pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back) + mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)) + mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5 + + noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) + prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * noise + + if old_pred_original_sample is None or prev_timestep < 0: + # Save a network evaluation if all noise levels are 0 or on the first step + return prev_sample, pred_original_sample + else: + denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample + noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) + x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise + + prev_sample = x_advanced + + if not return_dict: + return (prev_sample, pred_original_sample) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + def compute_pred_original_sample( + self, + model_output: torch.Tensor, + sample: torch.Tensor, + timestep: int + ) -> torch.Tensor: + """ + Compute the predicted original sample x_0 from the model output and the current noisy sample. + + Args: + model_output (`torch.Tensor`): The predicted noise from the model. + sample (`torch.Tensor`): The current noisy sample. + timestep (`int`): The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: The predicted original sample x_0. + """ + alpha_prod_t = self.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + return pred_original_sample diff --git a/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py new file mode 100755 index 0000000..4472a06 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -0,0 +1,1059 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_lu_lambdas (`bool`, *optional*, defaults to `False`): + Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during + the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of + `lambda(t)`. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + use_lu_lambdas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated + based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` + must be `None`, and `timestep_spacing` attribute will be ignored. + """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") + if timesteps is not None and self.config.use_lu_lambdas: + raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`") + + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.int64) + else: + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + + if self.config.use_karras_sigmas: + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + elif self.config.use_lu_lambdas: + lambdas = np.flip(log_sigmas.copy()) + lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) + sigmas = np.exp(lambdas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Lu et al. (2022).""" + + lambda_min: float = in_lambdas[-1].item() + lambda_max: float = in_lambdas[0].item() + + rho = 1.0 # 1.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = lambda_min ** (1 / rho) + max_inv_rho = lambda_max ** (1 / rho) + lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return lambdas + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 + ) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to(device=model_output.device, dtype=torch.float32) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py new file mode 100755 index 0000000..3f48066 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -0,0 +1,643 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + add_noise_common, +) + + +@flax.struct.dataclass +class DPMSolverMultistepSchedulerState: + common: CommonSchedulerState + alpha_t: jnp.ndarray + sigma_t: jnp.ndarray + lambda_t: jnp.ndarray + + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + + # running values + model_outputs: Optional[jnp.ndarray] = None + lower_order_nums: Optional[jnp.int32] = None + prev_timestep: Optional[jnp.int32] = None + cur_sample: Optional[jnp.ndarray] = None + + @classmethod + def create( + cls, + common: CommonSchedulerState, + alpha_t: jnp.ndarray, + sigma_t: jnp.ndarray, + lambda_t: jnp.ndarray, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): + return cls( + common=common, + alpha_t=alpha_t, + sigma_t=sigma_t, + lambda_t=lambda_t, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + +@dataclass +class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput): + state: DPMSolverMultistepSchedulerState + + +class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with + the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality + samples, and it can generate quite good samples even in only 10 steps. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We + recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, + or `v-prediction`. + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the + algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in + https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided + sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. + """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + timestep_spacing: str = "linspace", + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + # Currently we only support VP-type noise schedule + alpha_t = jnp.sqrt(common.alphas_cumprod) + sigma_t = jnp.sqrt(1 - common.alphas_cumprod) + lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) + + # settings for DPM-Solver + if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]: + raise NotImplementedError(f"{self.config.algorithm_type} is not implemented for {self.__class__}") + if self.config.solver_type not in ["midpoint", "heun"]: + raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}") + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + + return DPMSolverMultistepSchedulerState.create( + common=common, + alpha_t=alpha_t, + sigma_t=sigma_t, + lambda_t=lambda_t, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + def set_timesteps( + self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple + ) -> DPMSolverMultistepSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + shape (`Tuple`): + the shape of the samples to be generated. + """ + last_timestep = self.config.num_train_timesteps + if self.config.timestep_spacing == "linspace": + timesteps = ( + jnp.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].astype(jnp.int32) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (jnp.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(jnp.int32) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = jnp.arange(last_timestep, 0, -step_ratio).round().copy().astype(jnp.int32) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + # initial running values + + model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) + lower_order_nums = jnp.int32(0) + prev_timestep = jnp.int32(-1) + cur_sample = jnp.zeros(shape, dtype=self.dtype) + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + model_outputs=model_outputs, + lower_order_nums=lower_order_nums, + prev_timestep=prev_timestep, + cur_sample=cur_sample, + ) + + def convert_model_output( + self, + state: DPMSolverMultistepSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. + + Args: + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the converted model output. + """ + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type == "dpmsolver++": + if self.config.prediction_type == "epsilon": + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + # Dynamic thresholding in https://arxiv.org/abs/2205.11487 + dynamic_max_val = jnp.percentile( + jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim)) + ) + dynamic_max_val = jnp.maximum( + dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val) + ) + x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val + return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." + ) + + def dpm_solver_first_order_update( + self, + state: DPMSolverMultistepSchedulerState, + model_output: jnp.ndarray, + timestep: int, + prev_timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + One step for the first-order DPM-Solver (equivalent to DDIM). + + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + + Args: + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. + """ + t, s0 = prev_timestep, timestep + m0 = model_output + lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0] + alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0] + sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0] + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0 + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0 + return x_t + + def multistep_dpm_solver_second_order_update( + self, + state: DPMSolverMultistepSchedulerState, + model_output_list: jnp.ndarray, + timestep_list: List[int], + prev_timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + One step for the second-order multistep DPM-Solver. + + Args: + model_output_list (`List[jnp.ndarray]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1] + alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] + sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + state: DPMSolverMultistepSchedulerState, + model_output_list: jnp.ndarray, + timestep_list: List[int], + prev_timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + One step for the third-order multistep DPM-Solver. + + Args: + model_output_list (`List[jnp.ndarray]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + state.lambda_t[t], + state.lambda_t[s0], + state.lambda_t[s1], + state.lambda_t[s2], + ) + alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] + sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def step( + self, + state: DPMSolverMultistepSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process + from the learned model outputs (most often the predicted noise). + + Args: + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class + + Returns: + [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + (step_index,) = jnp.where(state.timesteps == timestep, size=1) + step_index = step_index[0] + + prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1]) + + model_output = self.convert_model_output(state, model_output, timestep, sample) + + model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0) + model_outputs_new = model_outputs_new.at[-1].set(model_output) + state = state.replace( + model_outputs=model_outputs_new, + prev_timestep=prev_timestep, + cur_sample=sample, + ) + + def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + return self.dpm_solver_first_order_update( + state, + state.model_outputs[-1], + state.timesteps[step_index], + state.prev_timestep, + state.cur_sample, + ) + + def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]]) + return self.multistep_dpm_solver_second_order_update( + state, + state.model_outputs, + timestep_list, + state.prev_timestep, + state.cur_sample, + ) + + def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + timestep_list = jnp.array( + [ + state.timesteps[step_index - 2], + state.timesteps[step_index - 1], + state.timesteps[step_index], + ] + ) + return self.multistep_dpm_solver_third_order_update( + state, + state.model_outputs, + timestep_list, + state.prev_timestep, + state.cur_sample, + ) + + step_2_output = step_2(state) + step_3_output = step_3(state) + + if self.config.solver_order == 2: + return step_2_output + elif self.config.lower_order_final and len(state.timesteps) < 15: + return jax.lax.select( + state.lower_order_nums < 2, + step_2_output, + jax.lax.select( + step_index == len(state.timesteps) - 2, + step_2_output, + step_3_output, + ), + ) + else: + return jax.lax.select( + state.lower_order_nums < 2, + step_2_output, + step_3_output, + ) + + step_1_output = step_1(state) + step_23_output = step_23(state) + + if self.config.solver_order == 1: + prev_sample = step_1_output + + elif self.config.lower_order_final and len(state.timesteps) < 15: + prev_sample = jax.lax.select( + state.lower_order_nums < 1, + step_1_output, + jax.lax.select( + step_index == len(state.timesteps) - 1, + step_1_output, + step_23_output, + ), + ) + + else: + prev_sample = jax.lax.select( + state.lower_order_nums < 1, + step_1_output, + step_23_output, + ) + + state = state.replace( + lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order), + ) + + if not return_dict: + return (prev_sample, state) + + return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state) + + def scale_model_input( + self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. + sample (`jnp.ndarray`): input sample + timestep (`int`, optional): current timestep + + Returns: + `jnp.ndarray`: scaled input sample + """ + return sample + + def add_noise( + self, + state: DPMSolverMultistepSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return add_noise_common(state.common, original_samples, noise, timesteps) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py new file mode 100755 index 0000000..6628a92 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -0,0 +1,921 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): + """ + `DPMSolverMultistepInverseScheduler` is the reverse scheduler of [`DPMSolverMultistepScheduler`]. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.use_karras_sigmas = use_karras_sigmas + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped).item() + self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = (self.noisiest_timestep + 1) // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.noisiest_timestep + 1, 0, -step_ratio).round()[::-1].copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', " + "'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = timesteps.copy().astype(np.int64) + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_max = ( + (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep] + ) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + self._step_index = step_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_sde.py new file mode 100755 index 0000000..7f2dd08 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -0,0 +1,574 @@ +# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torchsde + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get("w0", torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2**63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [ + torchsde.BrownianInterval( + t0=t0, + t1=t1, + size=w0.shape, + dtype=w0.dtype, + device=w0.device, + entropy=s, + tol=1e-6, + pool_size=24, + halfway_tree=True, + ) + for s in seed + ] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each + with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): + """ + DPMSolverSDEScheduler implements the stochastic sampler from the [Elucidating the Design Space of Diffusion-Based + Generative Models](https://huggingface.co/papers/2206.00364) paper. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.00085): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.012): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + noise_sampler_seed (`int`, *optional*, defaults to `None`): + The random seed to use for the noise sampler. If `None`, a random seed is generated. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + use_karras_sigmas: Optional[bool] = False, + noise_sampler_seed: Optional[int] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + self.use_karras_sigmas = use_karras_sigmas + self.noise_sampler = None + self.noise_sampler_seed = noise_sampler_seed + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_model_input( + self, + sample: torch.Tensor, + timestep: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma + sample = sample / ((sigma_input**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas) + + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + timesteps = torch.from_numpy(timesteps) + second_order_timesteps = torch.from_numpy(second_order_timesteps) + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) + timesteps[1::2] = second_order_timesteps + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = timesteps.to(device, dtype=torch.float32) + else: + self.timesteps = timesteps.to(device=device) + + # empty first order variables + self.sample = None + self.mid_point_sigma = None + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.noise_sampler = None + + def _second_order_timesteps(self, sigmas, log_sigmas): + def sigma_fn(_t): + return np.exp(-_t) + + def t_fn(_sigma): + return -np.log(_sigma) + + midpoint_ratio = 0.5 + t = t_fn(sigmas) + delta_time = np.diff(t) + t_proposed = t[:-1] + delta_time * midpoint_ratio + sig_proposed = sigma_fn(t_proposed) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sig_proposed]) + return timesteps + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, self.num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + @property + def state_in_first_order(self): + return self.sample is None + + def step( + self, + model_output: Union[torch.Tensor, np.ndarray], + timestep: Union[float, torch.Tensor], + sample: Union[torch.Tensor, np.ndarray], + return_dict: bool = True, + s_noise: float = 1.0, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor` or `np.ndarray`): + The direct output from learned diffusion model. + timestep (`float` or `torch.Tensor`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor` or `np.ndarray`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + s_noise (`float`, *optional*, defaults to 1.0): + Scaling factor for noise added to the sample. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.step_index is None: + self._init_step_index(timestep) + + # Create a noise sampler if it hasn't been created yet + if self.noise_sampler is None: + min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() + self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed) + + # Define functions to compute sigma and t from each other + def sigma_fn(_t: torch.Tensor) -> torch.Tensor: + return _t.neg().exp() + + def t_fn(_sigma: torch.Tensor) -> torch.Tensor: + return _sigma.log().neg() + + if self.state_in_first_order: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + else: + # 2nd order + sigma = self.sigmas[self.step_index - 1] + sigma_next = self.sigmas[self.step_index] + + # Set the midpoint and step size for the current step + midpoint_ratio = 0.5 + t, t_next = t_fn(sigma), t_fn(sigma_next) + delta_time = t_next - t + t_proposed = t + delta_time * midpoint_ratio + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed) + pred_original_sample = sample - sigma_input * model_output + elif self.config.prediction_type == "v_prediction": + sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed) + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) + ) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + if sigma_next == 0: + derivative = (sample - pred_original_sample) / sigma + dt = sigma_next - sigma + prev_sample = sample + derivative * dt + else: + if self.state_in_first_order: + t_next = t_proposed + else: + sample = self.sample + + sigma_from = sigma_fn(t) + sigma_to = sigma_fn(t_next) + sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + ancestral_t = t_fn(sigma_down) + prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - ( + t - ancestral_t + ).expm1() * pred_original_sample + prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up + + if self.state_in_first_order: + # store for 2nd order step + self.sample = sample + self.mid_point_sigma = sigma_fn(t_next) + else: + # free for "first order mode" + self.sample = None + self.mid_point_sigma = None + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py new file mode 100755 index 0000000..1a10fff --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -0,0 +1,1049 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate, logging +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): + """ + `DPMSolverSinglestepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver` + type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the + `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) + paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided + sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + final_sigmas_type (`str`, *optional*, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + ): + if algorithm_type == "dpmsolver": + deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.sample = None + self.order_list = self.get_order_list(num_train_timesteps) + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def get_order_list(self, num_inference_steps: int) -> List[int]: + """ + Computes the solver order at each time step. + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + steps = num_inference_steps + order = self.config.solver_order + if order > 3: + raise ValueError("Order > 3 is not supported by this scheduler") + if self.config.lower_order_final: + if order == 3: + if steps % 3 == 0: + orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1] + elif steps % 3 == 1: + orders = [1, 2, 3] * (steps // 3) + [1] + else: + orders = [1, 2, 3] * (steps // 3) + [1, 2] + elif order == 2: + if steps % 2 == 0: + orders = [1, 2] * (steps // 2 - 1) + [1, 1] + else: + orders = [1, 2] * (steps // 2) + [1] + elif order == 1: + orders = [1] * steps + else: + if order == 3: + orders = [1, 2, 3] * (steps // 3) + elif order == 2: + orders = [1, 2] * (steps // 2) + elif order == 1: + orders = [1] * steps + return orders + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is + passed, `num_inference_steps` must be `None`. + """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.") + + num_inference_steps = num_inference_steps or len(timesteps) + self.num_inference_steps = num_inference_steps + + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.int64) + else: + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas).to(device=device) + + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + self.model_outputs = [None] * self.config.solver_order + self.sample = None + + if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0: + logger.warning( + "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=False`." + ) + self.register_to_config(lower_order_final=True) + + if not self.config.lower_order_final and self.config.final_sigmas_type == "zero": + logger.warning( + " `last_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True." + ) + self.register_to_config(lower_order_final=True) + + self.order_list = self.get_order_list(num_inference_steps) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverSinglestepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverSinglestepScheduler." + ) + + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + return x_t + + def singlestep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the + time `timestep_list[-2]`. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): + The current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m1, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s1) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s1) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s1 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s1 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + return x_t + + def singlestep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the + time `timestep_list[-3]`. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): + The current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m2 + D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2) + D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s2) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s2) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s2) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s2) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def singlestep_dpm_solver_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + order: int = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the singlestep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): + The current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + order (`int`): + The solver order at this step. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing `order` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if order == 1: + return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise) + elif order == 2: + return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise) + elif order == 3: + return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample) + else: + raise ValueError(f"Order must be 1, 2, 3, got {order}") + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the singlestep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.algorithm_type == "sde-dpmsolver++": + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + + order = self.order_list[self.step_index] + + # For img2img denoising might start with order>1 which is not possible + # In this case make sure that the first two steps are both order=1 + while self.model_outputs[-order] is None: + order -= 1 + + # For single-step solvers, we use the initial value at each time with order = 1. + if order == 1: + self.sample = sample + + prev_sample = self.singlestep_dpm_solver_update( + self.model_outputs, sample=self.sample, order=order, noise=noise + ) + + # upon completion increase step index by one, noise=noise + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/diffusers/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py new file mode 100755 index 0000000..c49e8e9 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -0,0 +1,707 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + Implements DPMSolverMultistepScheduler in EDM formulation as presented in Karras et al. 2022 [1]. + `EDMDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + sigma_min (`float`, *optional*, defaults to 0.002): + Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable + range is [0, 10]. + sigma_max (`float`, *optional*, defaults to 80.0): + Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable + range is [0.2, 80.0]. + sigma_data (`float`, *optional*, defaults to 0.5): + The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1]. + sigma_schedule (`str`, *optional*, defaults to `karras`): + Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper + (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was + incorporated in this model: https://huggingface.co/stabilityai/cosxl. + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver++` type implements + the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to + use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + sigma_min: float = 0.002, + sigma_max: float = 80.0, + sigma_data: float = 0.5, + sigma_schedule: str = "karras", + num_train_timesteps: int = 1000, + prediction_type: str = "epsilon", + rho: float = 7.0, + solver_order: int = 2, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + ramp = torch.linspace(0, 1, num_train_timesteps) + if sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + self.timesteps = self.precondition_noise(sigmas) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + # setable values + self.num_inference_steps = None + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + return (self.config.sigma_max**2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs + def precondition_inputs(self, sample, sigma): + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + scaled_sample = sample * c_in + return scaled_sample + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise + def precondition_noise(self, sigma): + if not isinstance(sigma, torch.Tensor): + sigma = torch.tensor([sigma]) + + c_noise = 0.25 * torch.log(sigma) + + return c_noise + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs + def precondition_outputs(self, sample, model_output, sigma): + sigma_data = self.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + + if self.config.prediction_type == "epsilon": + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + elif self.config.prediction_type == "v_prediction": + c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + else: + raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.") + + denoised = c_skip * sample + c_out * model_output + + return denoised + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = self.precondition_inputs(sample, sigma) + + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + self.num_inference_steps = num_inference_steps + + ramp = torch.linspace(0, 1, self.num_inference_steps) + if self.config.sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif self.config.sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + sigmas = sigmas.to(dtype=torch.float32, device=device) + self.timesteps = self.precondition_noise(sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = self.config.sigma_min + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)]) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas + def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas + def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Implementation closely follows k-diffusion. + + https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + """ + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0) + return sigmas + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1 + sigma_t = sigma + + return alpha_t, sigma_t + + def convert_model_output( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + sigma = self.sigmas[self.step_index] + x0_pred = self.precondition_outputs(sample, model_output, sigma) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + sample: torch.Tensor = None, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.algorithm_type == "sde-dpmsolver++": + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_edm_euler.py b/diffusers/src/diffusers/schedulers/scheduling_edm_euler.py new file mode 100755 index 0000000..4b823c0 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_edm_euler.py @@ -0,0 +1,402 @@ +# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete +class EDMEulerSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +class EDMEulerScheduler(SchedulerMixin, ConfigMixin): + """ + Implements the Euler scheduler in EDM formulation as presented in Karras et al. 2022 [1]. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + sigma_min (`float`, *optional*, defaults to 0.002): + Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable + range is [0, 10]. + sigma_max (`float`, *optional*, defaults to 80.0): + Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable + range is [0.2, 80.0]. + sigma_data (`float`, *optional*, defaults to 0.5): + The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1]. + sigma_schedule (`str`, *optional*, defaults to `karras`): + Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper + (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was + incorporated in this model: https://huggingface.co/stabilityai/cosxl. + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + rho (`float`, *optional*, defaults to 7.0): + The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1]. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + sigma_min: float = 0.002, + sigma_max: float = 80.0, + sigma_data: float = 0.5, + sigma_schedule: str = "karras", + num_train_timesteps: int = 1000, + prediction_type: str = "epsilon", + rho: float = 7.0, + ): + if sigma_schedule not in ["karras", "exponential"]: + raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`") + + # setable values + self.num_inference_steps = None + + ramp = torch.linspace(0, 1, num_train_timesteps) + if sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + self.timesteps = self.precondition_noise(sigmas) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.is_scale_input_called = False + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + return (self.config.sigma_max**2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def precondition_inputs(self, sample, sigma): + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + scaled_sample = sample * c_in + return scaled_sample + + def precondition_noise(self, sigma): + if not isinstance(sigma, torch.Tensor): + sigma = torch.tensor([sigma]) + + c_noise = 0.25 * torch.log(sigma) + + return c_noise + + def precondition_outputs(self, sample, model_output, sigma): + sigma_data = self.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + + if self.config.prediction_type == "epsilon": + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + elif self.config.prediction_type == "v_prediction": + c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + else: + raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.") + + denoised = c_skip * sample + c_out * model_output + + return denoised + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = self.precondition_inputs(sample, sigma) + + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + ramp = torch.linspace(0, 1, self.num_inference_steps) + if self.config.sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif self.config.sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + sigmas = sigmas.to(dtype=torch.float32, device=device) + self.timesteps = self.precondition_noise(sigmas) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 + def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Implementation closely follows k-diffusion. + + https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + """ + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor], + sample: torch.Tensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EDMEulerSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EDMEulerScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/diffusers/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py new file mode 100755 index 0000000..485e919 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -0,0 +1,479 @@ +# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete +class EulerAncestralDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Ancestral sampling with Euler method steps. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ + ::-1 + ].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + + self.timesteps = torch.from_numpy(timesteps).to(device=device) + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor], + sample: torch.Tensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + + """ + + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + sigma_from = self.sigmas[self.step_index] + sigma_to = self.sigmas[self.step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + prev_sample = sample + derivative * dt + + device = model_output.device + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) + + prev_sample = prev_sample + noise * sigma_up + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_euler_discrete.py b/diffusers/src/diffusers/schedulers/scheduling_euler_discrete.py new file mode 100755 index 0000000..46e0e6b --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -0,0 +1,672 @@ +# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete +class EulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + interpolation_type(`str`, defaults to `"linear"`, *optional*): + The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be on of + `"linear"` or `"log_linear"`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + interpolation_type: str = "linear", + use_karras_sigmas: Optional[bool] = False, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + timestep_spacing: str = "linspace", + timestep_type: str = "discrete", # can be "discrete" or "continuous" + steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, + final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + + sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0) + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + # setable values + self.num_inference_steps = None + + # TODO: Support the full EDM scalings for all prediction types and timestep types + if timestep_type == "continuous" and prediction_type == "v_prediction": + self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]) + else: + self.timesteps = timesteps + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.is_scale_input_called = False + self.use_karras_sigmas = use_karras_sigmas + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max() + if self.config.timestep_spacing in ["linspace", "trailing"]: + return max_sigma + + return (max_sigma**2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + + self.is_scale_input_called = True + return sample + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated + based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` + must be `None`, and `timestep_spacing` attribute will be ignored. + sigmas (`List[float]`, *optional*): + Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas + will be generated based on the relevant scheduler attributes. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the + custom sigmas schedule. + """ + + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` should be set.") + if num_inference_steps is None and timesteps is None and sigmas is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.") + if num_inference_steps is not None and (timesteps is not None or sigmas is not None): + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.") + if ( + timesteps is not None + and self.config.timestep_type == "continuous" + and self.config.prediction_type == "v_prediction" + ): + raise ValueError( + "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." + ) + + if num_inference_steps is None: + num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1 + self.num_inference_steps = num_inference_steps + + if sigmas is not None: + log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)) + sigmas = np.array(sigmas).astype(np.float32) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]]) + + else: + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.float32) + else: + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace( + 0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32 + )[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + if self.config.interpolation_type == "linear": + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + elif self.config.interpolation_type == "log_linear": + sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy() + else: + raise ValueError( + f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" + " 'linear' or 'log_linear'" + ) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + # TODO: Support the full EDM scalings for all prediction types and timestep types + if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction": + self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device) + else: + self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) + + self._step_index = None + self._begin_index = None + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor], + sample: torch.Tensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma_hat * model_output + elif self.config.prediction_type == "v_prediction": + # denoised = model_output * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + if ( + isinstance(timesteps, int) + or isinstance(timesteps, torch.IntTensor) + or isinstance(timesteps, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if sample.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timesteps = timesteps.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timesteps = timesteps.to(sample.device) + + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + alphas_cumprod = self.alphas_cumprod.to(sample) + sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_euler_discrete_flax.py b/diffusers/src/diffusers/schedulers/scheduling_euler_discrete_flax.py new file mode 100755 index 0000000..55b0c24 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_euler_discrete_flax.py @@ -0,0 +1,265 @@ +# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + + +@flax.struct.dataclass +class EulerDiscreteSchedulerState: + common: CommonSchedulerState + + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + sigmas: jnp.ndarray + num_inference_steps: Optional[int] = None + + @classmethod + def create( + cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray + ): + return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) + + +@dataclass +class FlaxEulerDiscreteSchedulerOutput(FlaxSchedulerOutput): + state: EulerDiscreteSchedulerState + + +class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original + k-diffusion implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51 + + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. + """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 + sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas) + sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)]) + + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + init_noise_sigma = sigmas.max() + else: + init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5 + + return EulerDiscreteSchedulerState.create( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + sigmas=sigmas, + ) + + def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + state (`EulerDiscreteSchedulerState`): + the `FlaxEulerDiscreteScheduler` state data class instance. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + timestep (`int`): + current discrete timestep in the diffusion chain. + + Returns: + `jnp.ndarray`: scaled input sample + """ + (step_index,) = jnp.where(state.timesteps == timestep, size=1) + step_index = step_index[0] + + sigma = state.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> EulerDiscreteSchedulerState: + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`EulerDiscreteSchedulerState`): + the `FlaxEulerDiscreteScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + if self.config.timestep_spacing == "linspace": + timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // num_inference_steps + timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += 1 + else: + raise ValueError( + f"timestep_spacing must be one of ['linspace', 'leading'], got {self.config.timestep_spacing}" + ) + + sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5 + sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas) + sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)]) + + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + init_noise_sigma = sigmas.max() + else: + init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5 + + return state.replace( + timesteps=timesteps, + sigmas=sigmas, + num_inference_steps=num_inference_steps, + init_noise_sigma=init_noise_sigma, + ) + + def step( + self, + state: EulerDiscreteSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`EulerDiscreteSchedulerState`): + the `FlaxEulerDiscreteScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + order: coefficient for multi-step inference. + return_dict (`bool`): option for returning tuple rather than FlaxEulerDiscreteScheduler class + + Returns: + [`FlaxEulerDiscreteScheduler`] or `tuple`: [`FlaxEulerDiscreteScheduler`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + (step_index,) = jnp.where(state.timesteps == timestep, size=1) + step_index = step_index[0] + + sigma = state.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + # dt = sigma_down - sigma + dt = state.sigmas[step_index + 1] - sigma + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample, state) + + return FlaxEulerDiscreteSchedulerOutput(prev_sample=prev_sample, state=state) + + def add_noise( + self, + state: EulerDiscreteSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sigma = state.sigmas[timesteps].flatten() + sigma = broadcast_to_shape_from_left(sigma, noise.shape) + + noisy_samples = original_samples + noise * sigma + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/diffusers/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py new file mode 100755 index 0000000..81dae28 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -0,0 +1,446 @@ +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, is_scipy_available, logging +from .scheduling_utils import SchedulerMixin + + +if is_scipy_available(): + import scipy.stats + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) + self.num_inference_steps = num_inference_steps + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.config.num_train_timesteps + + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + + prev_sample = sample + (sigma_next - sigma) * model_output + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/diffusers/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py b/diffusers/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py new file mode 100755 index 0000000..d9a3ca2 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py @@ -0,0 +1,321 @@ +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Heun scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + timesteps = sigmas * self.config.num_train_timesteps + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) + self.timesteps = timesteps.to(device=device) + + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + # empty dt and derivative + self.prev_derivative = None + self.dt = None + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + @property + def state_in_first_order(self): + return self.dt is None + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + if self.state_in_first_order: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + else: + # 2nd order / Heun's method + sigma = self.sigmas[self.step_index - 1] + sigma_next = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + if self.state_in_first_order: + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + denoised = sample - model_output * sigma + # 2. convert to an ODE derivative for 1st order + derivative = (sample - denoised) / sigma_hat + # 3. Delta timestep + dt = sigma_next - sigma_hat + + # store for 2nd order step + self.prev_derivative = derivative + self.dt = dt + self.sample = sample + else: + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + denoised = sample - model_output * sigma_next + # 2. 2nd order / Heun's method + derivative = (sample - denoised) / sigma_next + derivative = 0.5 * (self.prev_derivative + derivative) + + # 3. take prev timestep & sample + dt = self.dt + sample = self.sample + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.prev_derivative = None + self.dt = None + self.sample = None + + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_heun_discrete.py b/diffusers/src/diffusers/schedulers/scheduling_heun_discrete.py new file mode 100755 index 0000000..8d0a4a8 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -0,0 +1,504 @@ +# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler with Heun steps for discrete beta schedules. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + use_karras_sigmas: Optional[bool] = False, + clip_sample: Optional[bool] = False, + clip_sample_range: float = 1.0, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine") + elif beta_schedule == "exp": + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp") + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + self.use_karras_sigmas = use_karras_sigmas + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_model_input( + self, + sample: torch.Tensor, + timestep: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + num_train_timesteps (`int`, *optional*): + The number of diffusion steps used when training the model. If `None`, the default + `num_train_timesteps` attribute is used. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be + generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` + must be `None`, and `timestep_spacing` attribute will be ignored. + """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") + + num_inference_steps = num_inference_steps or len(timesteps) + self.num_inference_steps = num_inference_steps + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + if timesteps is not None: + timesteps = np.array(timesteps, dtype=np.float32) + else: + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + timesteps = torch.from_numpy(timesteps) + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) + + self.timesteps = timesteps.to(device=device) + + # empty dt and derivative + self.prev_derivative = None + self.dt = None + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + @property + def state_in_first_order(self): + return self.dt is None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: Union[torch.Tensor, np.ndarray], + timestep: Union[float, torch.Tensor], + sample: Union[torch.Tensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.step_index is None: + self._init_step_index(timestep) + + if self.state_in_first_order: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + else: + # 2nd order / Heun's method + sigma = self.sigmas[self.step_index - 1] + sigma_next = self.sigmas[self.step_index] + + # currently only gamma=0 is supported. This usually works best anyways. + # We can support gamma in the future but then need to scale the timestep before + # passing it to the model which requires a change in API + gamma = 0 + sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + sigma_input = sigma_hat if self.state_in_first_order else sigma_next + pred_original_sample = sample - sigma_input * model_output + elif self.config.prediction_type == "v_prediction": + sigma_input = sigma_hat if self.state_in_first_order else sigma_next + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) + ) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + if self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + if self.state_in_first_order: + # 2. Convert to an ODE derivative for 1st order + derivative = (sample - pred_original_sample) / sigma_hat + # 3. delta timestep + dt = sigma_next - sigma_hat + + # store for 2nd order step + self.prev_derivative = derivative + self.dt = dt + self.sample = sample + else: + # 2. 2nd order / Heun's method + derivative = (sample - pred_original_sample) / sigma_next + derivative = (self.prev_derivative + derivative) / 2 + + # 3. take prev timestep & sample + dt = self.dt + sample = self.sample + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.prev_derivative = None + self.dt = None + self.sample = None + + prev_sample = sample + derivative * dt + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_ipndm.py b/diffusers/src/diffusers/schedulers/scheduling_ipndm.py new file mode 100755 index 0000000..28f349a --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_ipndm.py @@ -0,0 +1,224 @@ +# Copyright 2024 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class IPNDMScheduler(SchedulerMixin, ConfigMixin): + """ + A fourth-order Improved Pseudo Linear Multistep scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + """ + + order = 1 + + @register_to_config + def __init__( + self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None + ): + # set `betas`, `alphas`, `timesteps` + self.set_timesteps(num_train_timesteps) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + # running values + self.ets = [] + self._step_index = None + self._begin_index = None + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1] + steps = torch.cat([steps, torch.tensor([0.0])]) + + if self.config.trained_betas is not None: + self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32) + else: + self.betas = torch.sin(steps * math.pi / 2) ** 2 + + self.alphas = (1.0 - self.betas**2) ** 0.5 + + timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1] + self.timesteps = timesteps.to(device) + + self.ets = [] + self._step_index = None + self._begin_index = None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the linear multistep method. It performs one forward pass multiple times to approximate the solution. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + if self.step_index is None: + self._init_step_index(timestep) + + timestep_index = self.step_index + prev_timestep_index = self.step_index + 1 + + ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index] + self.ets.append(ets) + + if len(self.ets) == 1: + ets = self.ets[-1] + elif len(self.ets) == 2: + ets = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets): + alpha = self.alphas[timestep_index] + sigma = self.betas[timestep_index] + + next_alpha = self.alphas[prev_timestep_index] + next_sigma = self.betas[prev_timestep_index] + + pred = (sample - sigma * ets) / max(alpha, 1e-8) + prev_sample = next_alpha * pred + ets * next_sigma + + return prev_sample + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/diffusers/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py new file mode 100755 index 0000000..338412d --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -0,0 +1,512 @@ +# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + KDPM2DiscreteScheduler with ancestral sampling is inspired by the DPMSolver2 and Algorithm 2 from the [Elucidating + the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.00085): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.012): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + use_karras_sigmas: Optional[bool] = False, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_model_input( + self, + sample: torch.Tensor, + timestep: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + if self.state_in_first_order: + sigma = self.sigmas[self.step_index] + else: + sigma = self.sigmas_interpol[self.step_index - 1] + + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + + self.log_sigmas = torch.from_numpy(log_sigmas).to(device) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + + # compute up and down sigmas + sigmas_next = sigmas.roll(-1) + sigmas_next[-1] = 0.0 + sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5 + sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5 + sigmas_down[-1] = 0.0 + + # compute interpolated sigmas + sigmas_interpol = sigmas.log().lerp(sigmas_down.log(), 0.5).exp() + sigmas_interpol[-2:] = 0.0 + + # set sigmas + self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]]) + self.sigmas_interpol = torch.cat( + [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] + ) + self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]]) + self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) + + if str(device).startswith("mps"): + timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + timesteps = torch.from_numpy(timesteps).to(device) + + sigmas_interpol = sigmas_interpol.cpu() + log_sigmas = self.log_sigmas.cpu() + timesteps_interpol = np.array( + [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol] + ) + + timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype) + interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() + + self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) + + self.sample = None + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + @property + def state_in_first_order(self): + return self.sample is None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: Union[torch.Tensor, np.ndarray], + timestep: Union[float, torch.Tensor], + sample: Union[torch.Tensor, np.ndarray], + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.step_index is None: + self._init_step_index(timestep) + + if self.state_in_first_order: + sigma = self.sigmas[self.step_index] + sigma_interpol = self.sigmas_interpol[self.step_index] + sigma_up = self.sigmas_up[self.step_index] + sigma_down = self.sigmas_down[self.step_index - 1] + else: + # 2nd order / KPDM2's method + sigma = self.sigmas[self.step_index - 1] + sigma_interpol = self.sigmas_interpol[self.step_index - 1] + sigma_up = self.sigmas_up[self.step_index - 1] + sigma_down = self.sigmas_down[self.step_index - 1] + + # currently only gamma=0 is supported. This usually works best anyways. + # We can support gamma in the future but then need to scale the timestep before + # passing it to the model which requires a change in API + gamma = 0 + sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now + + device = model_output.device + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol + pred_original_sample = sample - sigma_input * model_output + elif self.config.prediction_type == "v_prediction": + sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) + ) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + if self.state_in_first_order: + # 2. Convert to an ODE derivative for 1st order + derivative = (sample - pred_original_sample) / sigma_hat + # 3. delta timestep + dt = sigma_interpol - sigma_hat + + # store for 2nd order step + self.sample = sample + self.dt = dt + prev_sample = sample + derivative * dt + else: + # DPM-Solver-2 + # 2. Convert to an ODE derivative for 2nd order + derivative = (sample - pred_original_sample) / sigma_interpol + # 3. delta timestep + dt = sigma_down - sigma_hat + + sample = self.sample + self.sample = None + + prev_sample = sample + derivative * dt + prev_sample = prev_sample + noise * sigma_up + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/diffusers/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py new file mode 100755 index 0000000..de66a7b --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -0,0 +1,487 @@ +# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + KDPM2DiscreteScheduler is inspired by the DPMSolver2 and Algorithm 2 from the [Elucidating the Design Space of + Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.00085): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.012): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + use_karras_sigmas: Optional[bool] = False, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_model_input( + self, + sample: torch.Tensor, + timestep: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + if self.state_in_first_order: + sigma = self.sigmas[self.step_index] + else: + sigma = self.sigmas_interpol[self.step_index] + + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + + self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + + # interpolate sigmas + sigmas_interpol = sigmas.log().lerp(sigmas.roll(1).log(), 0.5).exp() + + self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]]) + self.sigmas_interpol = torch.cat( + [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] + ) + + timesteps = torch.from_numpy(timesteps).to(device) + + # interpolate timesteps + sigmas_interpol = sigmas_interpol.cpu() + log_sigmas = self.log_sigmas.cpu() + timesteps_interpol = np.array( + [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol] + ) + timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype) + interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten() + + self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) + + self.sample = None + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def state_in_first_order(self): + return self.sample is None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def step( + self, + model_output: Union[torch.Tensor, np.ndarray], + timestep: Union[float, torch.Tensor], + sample: Union[torch.Tensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.step_index is None: + self._init_step_index(timestep) + + if self.state_in_first_order: + sigma = self.sigmas[self.step_index] + sigma_interpol = self.sigmas_interpol[self.step_index + 1] + sigma_next = self.sigmas[self.step_index + 1] + else: + # 2nd order / KDPM2's method + sigma = self.sigmas[self.step_index - 1] + sigma_interpol = self.sigmas_interpol[self.step_index] + sigma_next = self.sigmas[self.step_index] + + # currently only gamma=0 is supported. This usually works best anyways. + # We can support gamma in the future but then need to scale the timestep before + # passing it to the model which requires a change in API + gamma = 0 + sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol + pred_original_sample = sample - sigma_input * model_output + elif self.config.prediction_type == "v_prediction": + sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) + ) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + if self.state_in_first_order: + # 2. Convert to an ODE derivative for 1st order + derivative = (sample - pred_original_sample) / sigma_hat + # 3. delta timestep + dt = sigma_interpol - sigma_hat + + # store for 2nd order step + self.sample = sample + else: + # DPM-Solver-2 + # 2. Convert to an ODE derivative for 2nd order + derivative = (sample - pred_original_sample) / sigma_interpol + + # 3. delta timestep + dt = sigma_next - sigma_hat + + sample = self.sample + self.sample = None + + # upon completion increase step index by one + self._step_index += 1 + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/diffusers/src/diffusers/schedulers/scheduling_karras_ve_flax.py new file mode 100755 index 0000000..0d387b5 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -0,0 +1,238 @@ +# Copyright 2024 NVIDIA and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jnp +from jax import random + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils_flax import FlaxSchedulerMixin + + +@flax.struct.dataclass +class KarrasVeSchedulerState: + # setable values + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + schedule: Optional[jnp.ndarray] = None # sigma(t_i) + + @classmethod + def create(cls): + return cls() + + +@dataclass +class FlaxKarrasVeOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Derivative of predicted original image sample (x_0). + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + """ + + prev_sample: jnp.ndarray + derivative: jnp.ndarray + state: KarrasVeSchedulerState + + +class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and + the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic + differential equations." https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of + Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the + optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. + + Args: + sigma_min (`float`): minimum noise magnitude + sigma_max (`float`): maximum noise magnitude + s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. + A reasonable range is [1.000, 1.011]. + s_churn (`float`): the parameter controlling the overall amount of stochasticity. + A reasonable range is [0, 100]. + s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). + A reasonable range is [0, 10]. + s_max (`float`): the end value of the sigma range where we add noise. + A reasonable range is [0.2, 80]. + """ + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + sigma_min: float = 0.02, + sigma_max: float = 100, + s_noise: float = 1.007, + s_churn: float = 80, + s_min: float = 0.05, + s_max: float = 50, + ): + pass + + def create_state(self): + return KarrasVeSchedulerState.create() + + def set_timesteps( + self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> KarrasVeSchedulerState: + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`KarrasVeSchedulerState`): + the `FlaxKarrasVeScheduler` state data class. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + + """ + timesteps = jnp.arange(0, num_inference_steps)[::-1].copy() + schedule = [ + ( + self.config.sigma_max**2 + * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) + ) + for i in timesteps + ] + + return state.replace( + num_inference_steps=num_inference_steps, + schedule=jnp.array(schedule, dtype=jnp.float32), + timesteps=timesteps, + ) + + def add_noise_to_input( + self, + state: KarrasVeSchedulerState, + sample: jnp.ndarray, + sigma: float, + key: jax.Array, + ) -> Tuple[jnp.ndarray, float]: + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a + higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + + TODO Args: + """ + if self.config.s_min <= sigma <= self.config.s_max: + gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1) + else: + gamma = 0 + + # sample eps ~ N(0, S_noise^2 * I) + key = random.split(key, num=1) + eps = self.config.s_noise * random.normal(key=key, shape=sample.shape) + sigma_hat = sigma + gamma * sigma + sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) + + return sample_hat, sigma_hat + + def step( + self, + state: KarrasVeSchedulerState, + model_output: jnp.ndarray, + sigma_hat: float, + sigma_prev: float, + sample_hat: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxKarrasVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.Tensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class + + Returns: + [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion + chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + + pred_original_sample = sample_hat + sigma_hat * model_output + derivative = (sample_hat - pred_original_sample) / sigma_hat + sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative + + if not return_dict: + return (sample_prev, derivative, state) + + return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) + + def step_correct( + self, + state: KarrasVeSchedulerState, + model_output: jnp.ndarray, + sigma_hat: float, + sigma_prev: float, + sample_hat: jnp.ndarray, + sample_prev: jnp.ndarray, + derivative: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxKarrasVeOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. TODO complete description + + Args: + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.Tensor` or `np.ndarray`): TODO + sample_prev (`torch.Tensor` or `np.ndarray`): TODO + derivative (`torch.Tensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class + + Returns: + prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO + + """ + pred_original_sample = sample_prev + sigma_prev * model_output + derivative_corr = (sample_prev - pred_original_sample) / sigma_prev + sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) + + if not return_dict: + return (sample_prev, derivative, state) + + return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) + + def add_noise(self, state: KarrasVeSchedulerState, original_samples, noise, timesteps): + raise NotImplementedError() diff --git a/diffusers/src/diffusers/schedulers/scheduling_lcm.py b/diffusers/src/diffusers/schedulers/scheduling_lcm.py new file mode 100755 index 0000000..f1aa09a --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_lcm.py @@ -0,0 +1,658 @@ +# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class LCMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + denoised: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class LCMScheduler(SchedulerMixin, ConfigMixin): + """ + `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config + attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be + accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving + functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + original_inference_steps (`int`, *optional*, defaults to 50): + The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we + will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + timestep_scaling (`float`, defaults to 10.0): + The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions + `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation + error at the default of `10.0` is already pretty small). + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + original_inference_steps: int = 50, + clip_sample: bool = False, + clip_sample_range: float = 1.0, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + timestep_scaling: float = 10.0, + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.custom_timesteps = False + + self._step_index = None + self._begin_index = None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + @property + def step_index(self): + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + original_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + strength: int = 1.0, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + original_inference_steps (`int`, *optional*): + The original number of inference steps, which will be used to generate a linearly-spaced timestep + schedule (which is different from the standard `diffusers` implementation). We will then take + `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as + our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep + schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`. + """ + # 0. Check inputs + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.") + + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + + # 1. Calculate the LCM original training/distillation timestep schedule. + original_steps = ( + original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps + ) + + if original_steps > self.config.num_train_timesteps: + raise ValueError( + f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + # LCM Timesteps Setting + # The skipping step parameter k from the paper. + k = self.config.num_train_timesteps // original_steps + # LCM Training/Distillation Steps Schedule + # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts). + lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1 + + # 2. Calculate the LCM inference timestep schedule. + if timesteps is not None: + # 2.1 Handle custom timestep schedules. + train_timesteps = set(lcm_origin_timesteps) + non_train_timesteps = [] + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`custom_timesteps` must be in descending order.") + + if timesteps[i] not in train_timesteps: + non_train_timesteps.append(timesteps[i]) + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1 + if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1: + logger.warning( + f"The first timestep on the custom timestep schedule is {timesteps[0]}, not" + f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get" + f" unexpected results when using this timestep schedule." + ) + + # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule + if non_train_timesteps: + logger.warning( + f"The custom timestep schedule contains the following timesteps which are not on the original" + f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results" + f" when using this timestep schedule." + ) + + # Raise warning if custom timestep schedule is longer than original_steps + if len(timesteps) > original_steps: + logger.warning( + f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the" + f" the length of the timestep schedule used for training: {original_steps}. You may get some" + f" unexpected results when using this timestep schedule." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.num_inference_steps = len(timesteps) + self.custom_timesteps = True + + # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps) + init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps) + t_start = max(self.num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.order :] + # TODO: also reset self.num_inference_steps? + else: + # 2.2 Create the "standard" LCM inference timestep schedule. + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + skipping_step = len(lcm_origin_timesteps) // num_inference_steps + + if skipping_step < 1: + raise ValueError( + f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}." + ) + + self.num_inference_steps = num_inference_steps + + if num_inference_steps > original_steps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:" + f" {original_steps} because the final timestep schedule will be a subset of the" + f" `original_inference_steps`-sized initial timestep schedule." + ) + + # LCM Inference Steps Schedule + lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy() + # Select (approximately) evenly spaced indices from lcm_origin_timesteps. + inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + timesteps = lcm_origin_timesteps[inference_indices] + + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long) + + self._step_index = None + self._begin_index = None + + def get_scalings_for_boundary_condition_discrete(self, timestep): + self.sigma_data = 0.5 # Default: 0.5 + scaled_timestep = timestep * self.config.timestep_scaling + + c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5 + return c_skip, c_out + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[LCMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # 1. get previous step value + prev_step_index = self.step_index + 1 + if prev_step_index < len(self.timesteps): + prev_timestep = self.timesteps[prev_step_index] + else: + prev_timestep = timestep + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 3. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + + # 4. Compute the predicted original sample x_0 based on the model parameterization + if self.config.prediction_type == "epsilon": # noise-prediction + predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + elif self.config.prediction_type == "sample": # x-prediction + predicted_original_sample = model_output + elif self.config.prediction_type == "v_prediction": # v-prediction + predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `LCMScheduler`." + ) + + # 5. Clip or threshold "predicted x_0" + if self.config.thresholding: + predicted_original_sample = self._threshold_sample(predicted_original_sample) + elif self.config.clip_sample: + predicted_original_sample = predicted_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 6. Denoise model output using boundary conditions + denoised = c_out * predicted_original_sample + c_skip * sample + + # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference + # Noise is not used on the final timestep of the timestep schedule. + # This also means that noise is not used for one-step sampling. + if self.step_index != self.num_inference_steps - 1: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype + ) + prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise + else: + prev_sample = denoised + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample, denoised) + + return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep + def previous_timestep(self, timestep): + if self.custom_timesteps: + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + else: + num_inference_steps = ( + self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + ) + prev_t = timestep - self.config.num_train_timesteps // num_inference_steps + + return prev_t diff --git a/diffusers/src/diffusers/schedulers/scheduling_lms_discrete.py b/diffusers/src/diffusers/schedulers/scheduling_lms_discrete.py new file mode 100755 index 0000000..9595bb4 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -0,0 +1,477 @@ +# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from scipy import integrate + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete +class LMSDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + A linear multistep scheduler for discrete beta schedules. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + use_karras_sigmas: Optional[bool] = False, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # setable values + self.num_inference_steps = None + self.use_karras_sigmas = use_karras_sigmas + self.set_timesteps(num_train_timesteps, None) + self.derivatives = [] + self.is_scale_input_called = False + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`float` or `torch.Tensor`): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def get_lms_coefficient(self, order, t, current_order): + """ + Compute the linear multistep coefficient. + + Args: + order (): + t (): + current_order (): + """ + + def lms_derivative(tau): + prod = 1.0 + for k in range(order): + if current_order == k: + continue + prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k]) + return prod + + integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0] + + return integrated_coeff + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ + ::-1 + ].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas).to(device=device) + self.timesteps = torch.from_numpy(timesteps).to(device=device) + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + self.derivatives = [] + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, self.num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor], + sample: torch.Tensor, + order: int = 4, + return_dict: bool = True, + ) -> Union[LMSDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float` or `torch.Tensor`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`, defaults to 4): + The order of the linear multistep method. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if not self.is_scale_input_called: + warnings.warn( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + self.derivatives.append(derivative) + if len(self.derivatives) > order: + self.derivatives.pop(0) + + # 3. Compute linear multistep coefficients + order = min(self.step_index + 1, order) + lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)] + + # 4. Compute previous sample based on the derivatives path + prev_sample = sample + sum( + coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives)) + ) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/diffusers/src/diffusers/schedulers/scheduling_lms_discrete_flax.py new file mode 100755 index 0000000..f1169cc --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -0,0 +1,283 @@ +# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from scipy import integrate + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + + +@flax.struct.dataclass +class LMSDiscreteSchedulerState: + common: CommonSchedulerState + + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + sigmas: jnp.ndarray + num_inference_steps: Optional[int] = None + + # running values + derivatives: Optional[jnp.ndarray] = None + + @classmethod + def create( + cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray + ): + return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) + + +@dataclass +class FlaxLMSSchedulerOutput(FlaxSchedulerOutput): + state: LMSDiscreteSchedulerState + + +class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. + """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + init_noise_sigma = sigmas.max() + + return LMSDiscreteSchedulerState.create( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + sigmas=sigmas, + ) + + def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. + + Args: + state (`LMSDiscreteSchedulerState`): + the `FlaxLMSDiscreteScheduler` state data class instance. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + timestep (`int`): + current discrete timestep in the diffusion chain. + + Returns: + `jnp.ndarray`: scaled input sample + """ + (step_index,) = jnp.where(state.timesteps == timestep, size=1) + step_index = step_index[0] + + sigma = state.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order): + """ + Compute a linear multistep coefficient. + + Args: + order (TODO): + t (TODO): + current_order (TODO): + """ + + def lms_derivative(tau): + prod = 1.0 + for k in range(order): + if current_order == k: + continue + prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k]) + return prod + + integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0] + + return integrated_coeff + + def set_timesteps( + self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> LMSDiscreteSchedulerState: + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`LMSDiscreteSchedulerState`): + the `FlaxLMSDiscreteScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) + + low_idx = jnp.floor(timesteps).astype(jnp.int32) + high_idx = jnp.ceil(timesteps).astype(jnp.int32) + + frac = jnp.mod(timesteps, 1.0) + + sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5 + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)]) + + timesteps = timesteps.astype(jnp.int32) + + # initial running values + derivatives = jnp.zeros((0,) + shape, dtype=self.dtype) + + return state.replace( + timesteps=timesteps, + sigmas=sigmas, + num_inference_steps=num_inference_steps, + derivatives=derivatives, + ) + + def step( + self, + state: LMSDiscreteSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + order: int = 4, + return_dict: bool = True, + ) -> Union[FlaxLMSSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + order: coefficient for multi-step inference. + return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class + + Returns: + [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + sigma = state.sigmas[timestep] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + state = state.replace(derivatives=jnp.append(state.derivatives, derivative)) + if len(state.derivatives) > order: + state = state.replace(derivatives=jnp.delete(state.derivatives, 0)) + + # 3. Compute linear multistep coefficients + order = min(timestep + 1, order) + lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)] + + # 4. Compute previous sample based on the derivatives path + prev_sample = sample + sum( + coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives)) + ) + + if not return_dict: + return (prev_sample, state) + + return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state) + + def add_noise( + self, + state: LMSDiscreteSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sigma = state.sigmas[timesteps].flatten() + sigma = broadcast_to_shape_from_left(sigma, noise.shape) + + noisy_samples = original_samples + noise * sigma + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_pndm.py b/diffusers/src/diffusers/schedulers/scheduling_pndm.py new file mode 100755 index 0000000..a05e71c --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_pndm.py @@ -0,0 +1,476 @@ +# Copyright 2024 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class PNDMScheduler(SchedulerMixin, ConfigMixin): + """ + `PNDMScheduler` uses pseudo numerical methods for diffusion models such as the Runge-Kutta and linear multi-step + method. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + skip_prk_steps (`bool`, defaults to `False`): + Allows the scheduler to skip the Runge-Kutta steps defined in the original paper as being required before + PLMS steps. + set_alpha_to_one (`bool`, defaults to `False`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process) + or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) + paper). + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + prediction_type: str = "epsilon", + timestep_spacing: str = "leading", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + # running values + self.cur_model_output = 0 + self.counter = 0 + self.cur_sample = None + self.ets = [] + + # setable values + self.num_inference_steps = None + self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.prk_timesteps = None + self.plms_timesteps = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + self.num_inference_steps = num_inference_steps + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + self._timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() + self._timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype( + np.int64 + ) + self._timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + if self.config.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = np.array([]) + self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[ + ::-1 + ].copy() + else: + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy + + timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.ets = [] + self.counter = 0 + self.cur_model_output = 0 + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise), and calls [`~PNDMScheduler.step_prk`] + or [`~PNDMScheduler.step_plms`] depending on the internal variable `counter`. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: + return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) + else: + return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) + + def step_prk( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the Runge-Kutta method. It performs four forward passes to approximate the solution to the differential + equation. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 + prev_timestep = timestep - diff_to_prev + timestep = self.prk_timesteps[self.counter // 4 * 4] + + if self.counter % 4 == 0: + self.cur_model_output += 1 / 6 * model_output + self.ets.append(model_output) + self.cur_sample = sample + elif (self.counter - 1) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 2) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 3) % 4 == 0: + model_output = self.cur_model_output + 1 / 6 * model_output + self.cur_model_output = 0 + + # cur_sample should not be `None` + cur_sample = self.cur_sample if self.cur_sample is not None else sample + + prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + self.counter += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def step_plms( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the linear multistep method. It performs one forward pass multiple times to approximate the solution. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if not self.config.skip_prk_steps and len(self.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " + "for more information." + ) + + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + if self.counter != 1: + self.ets = self.ets[-3:] + self.ets.append(model_output) + else: + prev_timestep = timestep + timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps + + if len(self.ets) == 1 and self.counter == 0: + model_output = model_output + self.cur_sample = sample + elif len(self.ets) == 1 and self.counter == 1: + model_output = (model_output + self.ets[-1]) / 2 + sample = self.cur_sample + self.cur_sample = None + elif len(self.ets) == 2: + model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) + self.counter += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if self.config.prediction_type == "v_prediction": + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + elif self.config.prediction_type != "epsilon": + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`" + ) + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_pndm_flax.py b/diffusers/src/diffusers/schedulers/scheduling_pndm_flax.py new file mode 100755 index 0000000..3ac3ba5 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -0,0 +1,509 @@ +# Copyright 2024 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + add_noise_common, +) + + +@flax.struct.dataclass +class PNDMSchedulerState: + common: CommonSchedulerState + final_alpha_cumprod: jnp.ndarray + + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + prk_timesteps: Optional[jnp.ndarray] = None + plms_timesteps: Optional[jnp.ndarray] = None + + # running values + cur_model_output: Optional[jnp.ndarray] = None + counter: Optional[jnp.int32] = None + cur_sample: Optional[jnp.ndarray] = None + ets: Optional[jnp.ndarray] = None + + @classmethod + def create( + cls, + common: CommonSchedulerState, + final_alpha_cumprod: jnp.ndarray, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): + return cls( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + +@dataclass +class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput): + state: PNDMSchedulerState + + +class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, + namely Runge-Kutta method and a linear multi-step method. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + skip_prk_steps (`bool`): + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms steps; defaults to `False`. + set_alpha_to_one (`bool`, default `False`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. + """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + pndm_order: int + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 0, + prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + final_alpha_cumprod = ( + jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] + ) + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + + return PNDMSchedulerState.create( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`PNDMSchedulerState`): + the `FlaxPNDMScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + shape (`Tuple`): + the shape of the samples to be generated. + """ + + step_ratio = self.config.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # rounding to avoid issues when num_inference_step is power of 3 + _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset + + if self.config.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + + prk_timesteps = jnp.array([], dtype=jnp.int32) + plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1] + + else: + prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile( + jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32), + self.pndm_order, + ) + + prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1] + plms_timesteps = _timesteps[:-3][::-1] + + timesteps = jnp.concatenate([prk_timesteps, plms_timesteps]) + + # initial running values + + cur_model_output = jnp.zeros(shape, dtype=self.dtype) + counter = jnp.int32(0) + cur_sample = jnp.zeros(shape, dtype=self.dtype) + ets = jnp.zeros((4,) + shape, dtype=self.dtype) + + return state.replace( + timesteps=timesteps, + num_inference_steps=num_inference_steps, + prk_timesteps=prk_timesteps, + plms_timesteps=plms_timesteps, + cur_model_output=cur_model_output, + counter=counter, + cur_sample=cur_sample, + ets=ets, + ) + + def scale_model_input( + self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + sample (`jnp.ndarray`): input sample + timestep (`int`, optional): current timestep + + Returns: + `jnp.ndarray`: scaled input sample + """ + return sample + + def step( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class + + Returns: + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.config.skip_prk_steps: + prev_sample, state = self.step_plms(state, model_output, timestep, sample) + else: + prk_prev_sample, prk_state = self.step_prk(state, model_output, timestep, sample) + plms_prev_sample, plms_state = self.step_plms(state, model_output, timestep, sample) + + cond = state.counter < len(state.prk_timesteps) + + prev_sample = jax.lax.select(cond, prk_prev_sample, plms_prev_sample) + + state = state.replace( + cur_model_output=jax.lax.select(cond, prk_state.cur_model_output, plms_state.cur_model_output), + ets=jax.lax.select(cond, prk_state.ets, plms_state.ets), + cur_sample=jax.lax.select(cond, prk_state.cur_sample, plms_state.cur_sample), + counter=jax.lax.select(cond, prk_state.counter, plms_state.counter), + ) + + if not return_dict: + return (prev_sample, state) + + return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state) + + def step_prk( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class + + Returns: + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = jnp.where( + state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2 + ) + prev_timestep = timestep - diff_to_prev + timestep = state.prk_timesteps[state.counter // 4 * 4] + + model_output = jax.lax.select( + (state.counter % 4) != 3, + model_output, # remainder 0, 1, 2 + state.cur_model_output + 1 / 6 * model_output, # remainder 3 + ) + + state = state.replace( + cur_model_output=jax.lax.select_n( + state.counter % 4, + state.cur_model_output + 1 / 6 * model_output, # remainder 0 + state.cur_model_output + 1 / 3 * model_output, # remainder 1 + state.cur_model_output + 1 / 3 * model_output, # remainder 2 + jnp.zeros_like(state.cur_model_output), # remainder 3 + ), + ets=jax.lax.select( + (state.counter % 4) == 0, + state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # remainder 0 + state.ets, # remainder 1, 2, 3 + ), + cur_sample=jax.lax.select( + (state.counter % 4) == 0, + sample, # remainder 0 + state.cur_sample, # remainder 1, 2, 3 + ), + ) + + cur_sample = state.cur_sample + prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output) + state = state.replace(counter=state.counter + 1) + + return (prev_sample, state) + + def step_plms( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class + + Returns: + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # NOTE: There is no way to check in the jitted runtime if the prk mode was ran before + + prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps + prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0) + + # Reference: + # if state.counter != 1: + # state.ets.append(model_output) + # else: + # prev_timestep = timestep + # timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps + + prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) + timestep = jnp.where( + state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep + ) + + # Reference: + # if len(state.ets) == 1 and state.counter == 0: + # model_output = model_output + # state.cur_sample = sample + # elif len(state.ets) == 1 and state.counter == 1: + # model_output = (model_output + state.ets[-1]) / 2 + # sample = state.cur_sample + # state.cur_sample = None + # elif len(state.ets) == 2: + # model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 + # elif len(state.ets) == 3: + # model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12 + # else: + # model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]) + + state = state.replace( + ets=jax.lax.select( + state.counter != 1, + state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # counter != 1 + state.ets, # counter 1 + ), + cur_sample=jax.lax.select( + state.counter != 1, + sample, # counter != 1 + state.cur_sample, # counter 1 + ), + ) + + state = state.replace( + cur_model_output=jax.lax.select_n( + jnp.clip(state.counter, 0, 4), + model_output, # counter 0 + (model_output + state.ets[-1]) / 2, # counter 1 + (3 * state.ets[-1] - state.ets[-2]) / 2, # counter 2 + (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12, # counter 3 + (1 / 24) + * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]), # counter >= 4 + ), + ) + + sample = state.cur_sample + model_output = state.cur_model_output + prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output) + state = state.replace(counter=state.counter + 1) + + return (prev_sample, state) + + def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = state.common.alphas_cumprod[timestep] + alpha_prod_t_prev = jnp.where( + prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if self.config.prediction_type == "v_prediction": + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + elif self.config.prediction_type != "epsilon": + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`" + ) + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample + + def add_noise( + self, + state: PNDMSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return add_noise_common(state.common, original_samples, noise, timesteps) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_repaint.py b/diffusers/src/diffusers/schedulers/scheduling_repaint.py new file mode 100755 index 0000000..97665bb --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_repaint.py @@ -0,0 +1,361 @@ +# Copyright 2024 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +@dataclass +class RePaintSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from + the current timestep. `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: torch.Tensor + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class RePaintScheduler(SchedulerMixin, ConfigMixin): + """ + `RePaintScheduler` is a scheduler for DDPM inpainting inside a given mask. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, `squaredcos_cap_v2`, or `sigmoid`. + eta (`float`): + The weight of noise for added noise in diffusion step. If its value is between 0.0 and 1.0 it corresponds + to the DDIM scheduler, and if its value is between -0.0 and 1.0 it corresponds to the DDPM scheduler. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample between -1 and 1 for numerical stability. + + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + eta: float = 0.0, + trained_betas: Optional[np.ndarray] = None, + clip_sample: bool = True, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "sigmoid": + # GeoDiff sigmoid schedule + betas = torch.linspace(-6, 6, num_train_timesteps) + self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + self.final_alpha_cumprod = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + self.eta = eta + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps( + self, + num_inference_steps: int, + jump_length: int = 10, + jump_n_sample: int = 10, + device: Union[str, torch.device] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + jump_length (`int`, defaults to 10): + The number of steps taken forward in time before going backward in time for a single jump (“j” in + RePaint paper). Take a look at Figure 9 and 10 in the paper. + jump_n_sample (`int`, defaults to 10): + The number of times to make a forward time jump for a given chosen time sample. Take a look at Figure 9 + and 10 in the paper. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + + """ + num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + self.num_inference_steps = num_inference_steps + + timesteps = [] + + jumps = {} + for j in range(0, num_inference_steps - jump_length, jump_length): + jumps[j] = jump_n_sample - 1 + + t = num_inference_steps + while t >= 1: + t = t - 1 + timesteps.append(t) + + if jumps.get(t, 0) > 0: + jumps[t] = jumps[t] - 1 + for _ in range(jump_length): + t = t + 1 + timesteps.append(t) + + timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_variance(self, t): + prev_timestep = t - self.config.num_train_timesteps // self.num_inference_steps + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from + # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get + # previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add + # variance to pred_sample + # Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf + # without eta. + # variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + original_image: torch.Tensor, + mask: torch.Tensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[RePaintSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + original_image (`torch.Tensor`): + The original image to inpaint on. + mask (`torch.Tensor`): + The mask where a value of 0.0 indicates which part of the original image to inpaint. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_repaint.RePaintSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_repaint.RePaintSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_repaint.RePaintSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + + """ + t = timestep + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we + # substitute formula (7) in the algorithm coming from DDPM paper + # (formula (4) Algorithm 2 - Sampling) with formula (12) from DDIM paper. + # DDIM schedule gives the same results as DDPM with eta = 1.0 + # Noise is being reused in 7. and 8., but no impact on quality has + # been observed. + + # 5. Add noise + device = model_output.device + noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype) + std_dev_t = self.eta * self._get_variance(timestep) ** 0.5 + + variance = 0 + if t > 0 and self.eta > 0: + variance = std_dev_t * noise + + # 6. compute "direction pointing to x_t" of formula (12) + # from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output + + # 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance + + # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf + prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise + + # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf + pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part + + if not return_dict: + return ( + pred_prev_sample, + pred_original_sample, + ) + + return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + def undo_step(self, sample, timestep, generator=None): + n = self.config.num_train_timesteps // self.num_inference_steps + + for i in range(n): + beta = self.betas[timestep + i] + if sample.device.type == "mps": + # randn does not work reproducibly on mps + noise = randn_tensor(sample.shape, dtype=sample.dtype, generator=generator) + noise = noise.to(sample.device) + else: + noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) + + # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf + sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise + + return sample + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.") + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_sasolver.py b/diffusers/src/diffusers/schedulers/scheduling_sasolver.py new file mode 100755 index 0000000..50049a5 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_sasolver.py @@ -0,0 +1,1125 @@ +# Copyright 2024 Shuchen Xue, etc. in University of Chinese Academy of Sciences Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: check https://arxiv.org/abs/2309.05019 +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + +import math +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class SASolverScheduler(SchedulerMixin, ConfigMixin): + """ + `SASolverScheduler` is a fast dedicated high-order solver for diffusion SDEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + predictor_order (`int`, defaults to 2): + The predictor order which can be `1` or `2` or `3` or '4'. It is recommended to use `predictor_order=2` for + guided sampling, and `predictor_order=3` for unconditional sampling. + corrector_order (`int`, defaults to 2): + The corrector order which can be `1` or `2` or `3` or '4'. It is recommended to use `corrector_order=2` for + guided sampling, and `corrector_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + tau_func (`Callable`, *optional*): + Stochasticity during the sampling. Default in init is `lambda t: 1 if t >= 200 and t <= 800 else 0`. + SA-Solver will sample from vanilla diffusion ODE if tau_func is set to `lambda t: 0`. SA-Solver will sample + from vanilla diffusion SDE if tau_func is set to `lambda t: 1`. For more details, please check + https://arxiv.org/abs/2309.05019 + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `data_prediction`): + Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use + `data_prediction` with `solver_order=2` for guided sampling like in Stable Diffusion. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Default = True. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + predictor_order: int = 2, + corrector_order: int = 2, + prediction_type: str = "epsilon", + tau_func: Optional[Callable] = None, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "data_prediction", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + if algorithm_type not in ["data_prediction", "noise_prediction"]: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.timestep_list = [None] * max(predictor_order, corrector_order - 1) + self.model_outputs = [None] * max(predictor_order, corrector_order - 1) + + if tau_func is None: + self.tau_func = lambda t: 1 if t >= 200 and t <= 800 else 0 + else: + self.tau_func = tau_func + self.predict_x0 = algorithm_type == "data_prediction" + self.lower_order_nums = 0 + self.last_sample = None + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + ) + + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + self.model_outputs = [ + None, + ] * max(self.config.predictor_order, self.config.corrector_order - 1) + self.lower_order_nums = 0 + self.last_sample = None + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs. + Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is + designed to discretize an integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction for both + noise prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + # SA-Solver_data_prediction needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["data_prediction"]: + if self.config.prediction_type == "epsilon": + # SA-Solver only needs the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the SASolverScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # SA-Solver_noise_prediction needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["noise_prediction"]: + if self.config.prediction_type == "epsilon": + # SA-Solver only needs the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the SASolverScheduler." + ) + + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def get_coefficients_exponential_negative(self, order, interval_start, interval_end): + """ + Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end + """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + if order == 0: + return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1) + elif order == 1: + return torch.exp(-interval_end) * ( + (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1) + ) + elif order == 2: + return torch.exp(-interval_end) * ( + (interval_start**2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start) + - (interval_end**2 + 2 * interval_end + 2) + ) + elif order == 3: + return torch.exp(-interval_end) * ( + (interval_start**3 + 3 * interval_start**2 + 6 * interval_start + 6) + * torch.exp(interval_end - interval_start) + - (interval_end**3 + 3 * interval_end**2 + 6 * interval_end + 6) + ) + + def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau): + """ + Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end + """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + # after change of variable(cov) + interval_end_cov = (1 + tau**2) * interval_end + interval_start_cov = (1 + tau**2) * interval_start + + if order == 0: + return ( + torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / (1 + tau**2) + ) + elif order == 1: + return ( + torch.exp(interval_end_cov) + * ( + (interval_end_cov - 1) + - (interval_start_cov - 1) * torch.exp(-(interval_end_cov - interval_start_cov)) + ) + / ((1 + tau**2) ** 2) + ) + elif order == 2: + return ( + torch.exp(interval_end_cov) + * ( + (interval_end_cov**2 - 2 * interval_end_cov + 2) + - (interval_start_cov**2 - 2 * interval_start_cov + 2) + * torch.exp(-(interval_end_cov - interval_start_cov)) + ) + / ((1 + tau**2) ** 3) + ) + elif order == 3: + return ( + torch.exp(interval_end_cov) + * ( + (interval_end_cov**3 - 3 * interval_end_cov**2 + 6 * interval_end_cov - 6) + - (interval_start_cov**3 - 3 * interval_start_cov**2 + 6 * interval_start_cov - 6) + * torch.exp(-(interval_end_cov - interval_start_cov)) + ) + / ((1 + tau**2) ** 4) + ) + + def lagrange_polynomial_coefficient(self, order, lambda_list): + """ + Calculate the coefficient of lagrange polynomial + """ + + assert order in [0, 1, 2, 3] + assert order == len(lambda_list) - 1 + if order == 0: + return [[1]] + elif order == 1: + return [ + [ + 1 / (lambda_list[0] - lambda_list[1]), + -lambda_list[1] / (lambda_list[0] - lambda_list[1]), + ], + [ + 1 / (lambda_list[1] - lambda_list[0]), + -lambda_list[0] / (lambda_list[1] - lambda_list[0]), + ], + ] + elif order == 2: + denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) + denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) + denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) + return [ + [ + 1 / denominator1, + (-lambda_list[1] - lambda_list[2]) / denominator1, + lambda_list[1] * lambda_list[2] / denominator1, + ], + [ + 1 / denominator2, + (-lambda_list[0] - lambda_list[2]) / denominator2, + lambda_list[0] * lambda_list[2] / denominator2, + ], + [ + 1 / denominator3, + (-lambda_list[0] - lambda_list[1]) / denominator3, + lambda_list[0] * lambda_list[1] / denominator3, + ], + ] + elif order == 3: + denominator1 = ( + (lambda_list[0] - lambda_list[1]) + * (lambda_list[0] - lambda_list[2]) + * (lambda_list[0] - lambda_list[3]) + ) + denominator2 = ( + (lambda_list[1] - lambda_list[0]) + * (lambda_list[1] - lambda_list[2]) + * (lambda_list[1] - lambda_list[3]) + ) + denominator3 = ( + (lambda_list[2] - lambda_list[0]) + * (lambda_list[2] - lambda_list[1]) + * (lambda_list[2] - lambda_list[3]) + ) + denominator4 = ( + (lambda_list[3] - lambda_list[0]) + * (lambda_list[3] - lambda_list[1]) + * (lambda_list[3] - lambda_list[2]) + ) + return [ + [ + 1 / denominator1, + (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1, + ( + lambda_list[1] * lambda_list[2] + + lambda_list[1] * lambda_list[3] + + lambda_list[2] * lambda_list[3] + ) + / denominator1, + (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1, + ], + [ + 1 / denominator2, + (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2, + ( + lambda_list[0] * lambda_list[2] + + lambda_list[0] * lambda_list[3] + + lambda_list[2] * lambda_list[3] + ) + / denominator2, + (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2, + ], + [ + 1 / denominator3, + (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3, + ( + lambda_list[0] * lambda_list[1] + + lambda_list[0] * lambda_list[3] + + lambda_list[1] * lambda_list[3] + ) + / denominator3, + (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3, + ], + [ + 1 / denominator4, + (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4, + ( + lambda_list[0] * lambda_list[1] + + lambda_list[0] * lambda_list[2] + + lambda_list[1] * lambda_list[2] + ) + / denominator4, + (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4, + ], + ] + + def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau): + assert order in [1, 2, 3, 4] + assert order == len(lambda_list), "the length of lambda list must be equal to the order" + coefficients = [] + lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list) + for i in range(order): + coefficient = 0 + for j in range(order): + if self.predict_x0: + coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_positive( + order - 1 - j, interval_start, interval_end, tau + ) + else: + coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_negative( + order - 1 - j, interval_start, interval_end + ) + coefficients.append(coefficient) + assert len(coefficients) == order, "the length of coefficients does not match the order" + return coefficients + + def stochastic_adams_bashforth_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor, + noise: torch.Tensor, + order: int, + tau: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """ + One step for the SA-Predictor. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of SA-Predictor at this timestep. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if noise is None: + if len(args) > 2: + noise = args[2] + else: + raise ValueError(" missing `noise` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing `order` as a required keyward argument") + if tau is None: + if len(args) > 4: + tau = args[4] + else: + raise ValueError(" missing `tau` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + gradient_part = torch.zeros_like(sample) + h = lambda_t - lambda_s0 + lambda_list = [] + + for i in range(order): + si = self.step_index - i + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + lambda_list.append(lambda_si) + + gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau) + + x = sample + + if self.predict_x0: + if ( + order == 2 + ): ## if order = 2 we do a modification that does not influence the convergence order similar to unipc. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. + # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + temp_sigma = self.sigmas[self.step_index - 1] + temp_alpha_s, temp_sigma_s = self._sigma_to_alpha_sigma_t(temp_sigma) + temp_lambda_s = torch.log(temp_alpha_s) - torch.log(temp_sigma_s) + gradient_coefficients[0] += ( + 1.0 + * torch.exp((1 + tau**2) * lambda_t) + * (h**2 / 2 - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) / ((1 + tau**2) ** 2)) + / (lambda_s0 - temp_lambda_s) + ) + gradient_coefficients[1] -= ( + 1.0 + * torch.exp((1 + tau**2) * lambda_t) + * (h**2 / 2 - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) / ((1 + tau**2) ** 2)) + / (lambda_s0 - temp_lambda_s) + ) + + for i in range(order): + if self.predict_x0: + gradient_part += ( + (1 + tau**2) + * sigma_t + * torch.exp(-(tau**2) * lambda_t) + * gradient_coefficients[i] + * model_output_list[-(i + 1)] + ) + else: + gradient_part += -(1 + tau**2) * alpha_t * gradient_coefficients[i] * model_output_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau**2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = torch.exp(-(tau**2) * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part + + x_t = x_t.to(x.dtype) + return x_t + + def stochastic_adams_moulton_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor, + last_noise: torch.Tensor, + this_sample: torch.Tensor, + order: int, + tau: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """ + One step for the SA-Corrector. + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The order of SA-Corrector at this step. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyward argument") + if last_noise is None: + if len(args) > 2: + last_noise = args[2] + else: + raise ValueError(" missing`last_noise` as a required keyward argument") + if this_sample is None: + if len(args) > 3: + this_sample = args[3] + else: + raise ValueError(" missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 4: + order = args[4] + else: + raise ValueError(" missing`order` as a required keyward argument") + if tau is None: + if len(args) > 5: + tau = args[5] + else: + raise ValueError(" missing`tau` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + gradient_part = torch.zeros_like(this_sample) + h = lambda_t - lambda_s0 + lambda_list = [] + for i in range(order): + si = self.step_index - i + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + lambda_list.append(lambda_si) + + model_prev_list = model_output_list + [this_model_output] + + gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau) + + x = last_sample + + if self.predict_x0: + if ( + order == 2 + ): ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. + # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + gradient_coefficients[0] += ( + 1.0 + * torch.exp((1 + tau**2) * lambda_t) + * (h / 2 - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) / ((1 + tau**2) ** 2 * h)) + ) + gradient_coefficients[1] -= ( + 1.0 + * torch.exp((1 + tau**2) * lambda_t) + * (h / 2 - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) / ((1 + tau**2) ** 2 * h)) + ) + + for i in range(order): + if self.predict_x0: + gradient_part += ( + (1 + tau**2) + * sigma_t + * torch.exp(-(tau**2) * lambda_t) + * gradient_coefficients[i] + * model_prev_list[-(i + 1)] + ) + else: + gradient_part += -(1 + tau**2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau**2 * h)) * last_noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * last_noise + + if self.predict_x0: + x_t = torch.exp(-(tau**2) * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part + + x_t = x_t.to(x.dtype) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the SA-Solver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = self.step_index > 0 and self.last_sample is not None + + model_output_convert = self.convert_model_output(model_output, sample=sample) + + if use_corrector: + current_tau = self.tau_func(self.timestep_list[-1]) + sample = self.stochastic_adams_moulton_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + last_noise=self.last_noise, + this_sample=sample, + order=self.this_corrector_order, + tau=current_tau, + ) + + for i in range(max(self.config.predictor_order, self.config.corrector_order - 1) - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + + if self.config.lower_order_final: + this_predictor_order = min(self.config.predictor_order, len(self.timesteps) - self.step_index) + this_corrector_order = min(self.config.corrector_order, len(self.timesteps) - self.step_index + 1) + else: + this_predictor_order = self.config.predictor_order + this_corrector_order = self.config.corrector_order + + self.this_predictor_order = min(this_predictor_order, self.lower_order_nums + 1) # warmup for multistep + self.this_corrector_order = min(this_corrector_order, self.lower_order_nums + 2) # warmup for multistep + assert self.this_predictor_order > 0 + assert self.this_corrector_order > 0 + + self.last_sample = sample + self.last_noise = noise + + current_tau = self.tau_func(self.timestep_list[-1]) + prev_sample = self.stochastic_adams_bashforth_update( + model_output=model_output_convert, + sample=sample, + noise=noise, + order=self.this_predictor_order, + tau=current_tau, + ) + + if self.lower_order_nums < max(self.config.predictor_order, self.config.corrector_order - 1): + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_sde_ve.py b/diffusers/src/diffusers/schedulers/scheduling_sde_ve.py new file mode 100755 index 0000000..cedfbf7 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_sde_ve.py @@ -0,0 +1,301 @@ +# Copyright 2024 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +@dataclass +class SdeVeOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + prev_sample_mean (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Mean averaged `prev_sample` over previous timesteps. + """ + + prev_sample: torch.Tensor + prev_sample_mean: torch.Tensor + + +class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): + """ + `ScoreSdeVeScheduler` is a variance exploding stochastic differential equation (SDE) scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + snr (`float`, defaults to 0.15): + A coefficient weighting the step from the `model_output` sample (from the network) to the random noise. + sigma_min (`float`, defaults to 0.01): + The initial noise scale for the sigma sequence in the sampling procedure. The minimum sigma should mirror + the distribution of the data. + sigma_max (`float`, defaults to 1348.0): + The maximum value used for the range of continuous timesteps passed into the model. + sampling_eps (`float`, defaults to 1e-5): + The end value of sampling where timesteps decrease progressively from 1 to epsilon. + correct_steps (`int`, defaults to 1): + The number of correction steps performed on a produced sample. + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 2000, + snr: float = 0.15, + sigma_min: float = 0.01, + sigma_max: float = 1348.0, + sampling_eps: float = 1e-5, + correct_steps: int = 1, + ): + # standard deviation of the initial noise distribution + self.init_noise_sigma = sigma_max + + # setable values + self.timesteps = None + + self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps( + self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None + ): + """ + Sets the continuous timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + sampling_eps (`float`, *optional*): + The final timestep value (overrides value given during scheduler instantiation). + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + + """ + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + + self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device) + + def set_sigmas( + self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None + ): + """ + Sets the noise scales used for the diffusion chain (to be run before inference). The sigmas control the weight + of the `drift` and `diffusion` components of the sample update. + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + sigma_min (`float`, optional): + The initial noise scale value (overrides value given during scheduler instantiation). + sigma_max (`float`, optional): + The final noise scale value (overrides value given during scheduler instantiation). + sampling_eps (`float`, optional): + The final timestep value (overrides value given during scheduler instantiation). + + """ + sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min + sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + if self.timesteps is None: + self.set_timesteps(num_inference_steps, sampling_eps) + + self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps) + self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps)) + self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) + + def get_adjacent_sigma(self, timesteps, t): + return torch.where( + timesteps == 0, + torch.zeros_like(t.to(timesteps.device)), + self.discrete_sigmas[timesteps - 1].to(timesteps.device), + ) + + def step_pred( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[SdeVeOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a tuple + is returned where the first element is the sample tensor. + + """ + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + timestep = timestep * torch.ones( + sample.shape[0], device=sample.device + ) # torch.repeat_interleave(timestep, sample.shape[0]) + timesteps = (timestep * (len(self.timesteps) - 1)).long() + + # mps requires indices to be in the same device, so we use cpu as is the default with cuda + timesteps = timesteps.to(self.discrete_sigmas.device) + + sigma = self.discrete_sigmas[timesteps].to(sample.device) + adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) + drift = torch.zeros_like(sample) + diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 + + # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) + # also equation 47 shows the analog from SDE models to ancestral sampling methods + diffusion = diffusion.flatten() + while len(diffusion.shape) < len(sample.shape): + diffusion = diffusion.unsqueeze(-1) + drift = drift - diffusion**2 * model_output + + # equation 6: sample noise for the diffusion term of + noise = randn_tensor( + sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype + ) + prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep + # TODO is the variable diffusion the correct scaling term for the noise? + prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g + + if not return_dict: + return (prev_sample, prev_sample_mean) + + return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean) + + def step_correct( + self, + model_output: torch.Tensor, + sample: torch.Tensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Correct the predicted sample based on the `model_output` of the network. This is often run repeatedly after + making the prediction for the previous timestep. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a tuple + is returned where the first element is the sample tensor. + + """ + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" + # sample noise for correction + noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator).to(sample.device) + + # compute step size from the model_output, the noise, and the snr + grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean() + noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() + step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = step_size * torch.ones(sample.shape[0]).to(sample.device) + # self.repeat_scalar(step_size, sample.shape[0]) + + # compute corrected sample: model_output term and noise term + step_size = step_size.flatten() + while len(step_size.shape) < len(sample.shape): + step_size = step_size.unsqueeze(-1) + prev_sample_mean = sample + step_size * model_output + prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + timesteps = timesteps.to(original_samples.device) + sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps] + noise = ( + noise * sigmas[:, None, None, None] + if noise is not None + else torch.randn_like(original_samples) * sigmas[:, None, None, None] + ) + noisy_samples = noise + original_samples + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/diffusers/src/diffusers/schedulers/scheduling_sde_ve_flax.py new file mode 100755 index 0000000..0a8d45d --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -0,0 +1,280 @@ +# Copyright 2024 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jnp +from jax import random + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left + + +@flax.struct.dataclass +class ScoreSdeVeSchedulerState: + # setable values + timesteps: Optional[jnp.ndarray] = None + discrete_sigmas: Optional[jnp.ndarray] = None + sigmas: Optional[jnp.ndarray] = None + + @classmethod + def create(cls): + return cls() + + +@dataclass +class FlaxSdeVeOutput(FlaxSchedulerOutput): + """ + Output class for the ScoreSdeVeScheduler's step function output. + + Args: + state (`ScoreSdeVeSchedulerState`): + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + prev_sample_mean (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. + """ + + state: ScoreSdeVeSchedulerState + prev_sample: jnp.ndarray + prev_sample_mean: Optional[jnp.ndarray] = None + + +class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + The variance exploding stochastic differential equation (SDE) scheduler. + + For more information, see the original paper: https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + snr (`float`): + coefficient weighting the step from the model_output sample (from the network) to the random noise. + sigma_min (`float`): + initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the + distribution of the data. + sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model. + sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to + epsilon. + correct_steps (`int`): number of correction steps performed on a produced sample. + """ + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 2000, + snr: float = 0.15, + sigma_min: float = 0.01, + sigma_max: float = 1348.0, + sampling_eps: float = 1e-5, + correct_steps: int = 1, + ): + pass + + def create_state(self): + state = ScoreSdeVeSchedulerState.create() + return self.set_sigmas( + state, + self.config.num_train_timesteps, + self.config.sigma_min, + self.config.sigma_max, + self.config.sampling_eps, + ) + + def set_timesteps( + self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None + ) -> ScoreSdeVeSchedulerState: + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sampling_eps (`float`, optional): + final timestep value (overrides value given at Scheduler instantiation). + + """ + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + + timesteps = jnp.linspace(1, sampling_eps, num_inference_steps) + return state.replace(timesteps=timesteps) + + def set_sigmas( + self, + state: ScoreSdeVeSchedulerState, + num_inference_steps: int, + sigma_min: float = None, + sigma_max: float = None, + sampling_eps: float = None, + ) -> ScoreSdeVeSchedulerState: + """ + Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. + + The sigmas control the weight of the `drift` and `diffusion` components of sample update. + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sigma_min (`float`, optional): + initial noise scale value (overrides value given at Scheduler instantiation). + sigma_max (`float`, optional): + final noise scale value (overrides value given at Scheduler instantiation). + sampling_eps (`float`, optional): + final timestep value (overrides value given at Scheduler instantiation). + """ + sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min + sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + if state.timesteps is None: + state = self.set_timesteps(state, num_inference_steps, sampling_eps) + + discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps)) + sigmas = jnp.array([sigma_min * (sigma_max / sigma_min) ** t for t in state.timesteps]) + + return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas) + + def get_adjacent_sigma(self, state, timesteps, t): + return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1]) + + def step_pred( + self, + state: ScoreSdeVeSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + key: jax.Array, + return_dict: bool = True, + ) -> Union[FlaxSdeVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class + + Returns: + [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if state.timesteps is None: + raise ValueError( + "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + timestep = timestep * jnp.ones( + sample.shape[0], + ) + timesteps = (timestep * (len(state.timesteps) - 1)).long() + + sigma = state.discrete_sigmas[timesteps] + adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep) + drift = jnp.zeros_like(sample) + diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 + + # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) + # also equation 47 shows the analog from SDE models to ancestral sampling methods + diffusion = diffusion.flatten() + diffusion = broadcast_to_shape_from_left(diffusion, sample.shape) + drift = drift - diffusion**2 * model_output + + # equation 6: sample noise for the diffusion term of + key = random.split(key, num=1) + noise = random.normal(key=key, shape=sample.shape) + prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep + # TODO is the variable diffusion the correct scaling term for the noise? + prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g + + if not return_dict: + return (prev_sample, prev_sample_mean, state) + + return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state) + + def step_correct( + self, + state: ScoreSdeVeSchedulerState, + model_output: jnp.ndarray, + sample: jnp.ndarray, + key: jax.Array, + return_dict: bool = True, + ) -> Union[FlaxSdeVeOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. This is often run repeatedly + after making the prediction for the previous timestep. + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class + + Returns: + [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if state.timesteps is None: + raise ValueError( + "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" + # sample noise for correction + key = random.split(key, num=1) + noise = random.normal(key=key, shape=sample.shape) + + # compute step size from the model_output, the noise, and the snr + grad_norm = jnp.linalg.norm(model_output) + noise_norm = jnp.linalg.norm(noise) + step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = step_size * jnp.ones(sample.shape[0]) + + # compute corrected sample: model_output term and noise term + step_size = step_size.flatten() + step_size = broadcast_to_shape_from_left(step_size, sample.shape) + prev_sample_mean = sample + step_size * model_output + prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise + + if not return_dict: + return (prev_sample, state) + + return FlaxSdeVeOutput(prev_sample=prev_sample, state=state) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_tcd.py b/diffusers/src/diffusers/schedulers/scheduling_tcd.py new file mode 100755 index 0000000..5802244 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_tcd.py @@ -0,0 +1,695 @@ +# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..schedulers.scheduling_utils import SchedulerMixin +from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class TCDSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_noised_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted noised sample `(x_{s})` based on the model output from the current timestep. + """ + + prev_sample: torch.Tensor + pred_noised_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class TCDScheduler(SchedulerMixin, ConfigMixin): + """ + `TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency + Distillation`, extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal. + + This code is based on the official repo of TCD(https://github.com/jabir-zheng/TCD). + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config + attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be + accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving + functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + original_inference_steps (`int`, *optional*, defaults to 50): + The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we + will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + timestep_scaling (`float`, defaults to 10.0): + The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions + `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation + error at the default of `10.0` is already pretty small). + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + original_inference_steps: int = 50, + clip_sample: bool = False, + clip_sample_range: float = 1.0, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + timestep_scaling: float = 10.0, + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.custom_timesteps = False + + self._step_index = None + self._begin_index = None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + @property + def step_index(self): + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + original_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + strength: float = 1.0, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + original_inference_steps (`int`, *optional*): + The original number of inference steps, which will be used to generate a linearly-spaced timestep + schedule (which is different from the standard `diffusers` implementation). We will then take + `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as + our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep + schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`. + strength (`float`, *optional*, defaults to 1.0): + Used to determine the number of timesteps used for inference when using img2img, inpaint, etc. + """ + # 0. Check inputs + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.") + + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + + # 1. Calculate the TCD original training/distillation timestep schedule. + original_steps = ( + original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps + ) + + if original_inference_steps is None: + # default option, timesteps align with discrete inference steps + if original_steps > self.config.num_train_timesteps: + raise ValueError( + f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + # TCD Timesteps Setting + # The skipping step parameter k from the paper. + k = self.config.num_train_timesteps // original_steps + # TCD Training/Distillation Steps Schedule + tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1 + else: + # customised option, sampled timesteps can be any arbitrary value + tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps * strength)))) + + # 2. Calculate the TCD inference timestep schedule. + if timesteps is not None: + # 2.1 Handle custom timestep schedules. + train_timesteps = set(tcd_origin_timesteps) + non_train_timesteps = [] + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`custom_timesteps` must be in descending order.") + + if timesteps[i] not in train_timesteps: + non_train_timesteps.append(timesteps[i]) + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1 + if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1: + logger.warning( + f"The first timestep on the custom timestep schedule is {timesteps[0]}, not" + f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get" + f" unexpected results when using this timestep schedule." + ) + + # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule + if non_train_timesteps: + logger.warning( + f"The custom timestep schedule contains the following timesteps which are not on the original" + f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results" + f" when using this timestep schedule." + ) + + # Raise warning if custom timestep schedule is longer than original_steps + if original_steps is not None: + if len(timesteps) > original_steps: + logger.warning( + f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the" + f" the length of the timestep schedule used for training: {original_steps}. You may get some" + f" unexpected results when using this timestep schedule." + ) + else: + if len(timesteps) > self.config.num_train_timesteps: + logger.warning( + f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the" + f" the length of the timestep schedule used for training: {self.config.num_train_timesteps}. You may get some" + f" unexpected results when using this timestep schedule." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.num_inference_steps = len(timesteps) + self.custom_timesteps = True + + # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps) + init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps) + t_start = max(self.num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.order :] + # TODO: also reset self.num_inference_steps? + else: + # 2.2 Create the "standard" TCD inference timestep schedule. + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + if original_steps is not None: + skipping_step = len(tcd_origin_timesteps) // num_inference_steps + + if skipping_step < 1: + raise ValueError( + f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}." + ) + + self.num_inference_steps = num_inference_steps + + if original_steps is not None: + if num_inference_steps > original_steps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:" + f" {original_steps} because the final timestep schedule will be a subset of the" + f" `original_inference_steps`-sized initial timestep schedule." + ) + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `num_train_timesteps`:" + f" {self.config.num_train_timesteps} because the final timestep schedule will be a subset of the" + f" `num_train_timesteps`-sized initial timestep schedule." + ) + + # TCD Inference Steps Schedule + tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy() + # Select (approximately) evenly spaced indices from tcd_origin_timesteps. + inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + timesteps = tcd_origin_timesteps[inference_indices] + + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long) + + self._step_index = None + self._begin_index = None + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.3, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[TCDSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every + step. When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic + sampling. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.TCDSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + assert 0 <= eta <= 1.0, "gamma must be less than or equal to 1.0" + + # 1. get previous step value + prev_step_index = self.step_index + 1 + if prev_step_index < len(self.timesteps): + prev_timestep = self.timesteps[prev_step_index] + else: + prev_timestep = torch.tensor(0) + + timestep_s = torch.floor((1 - eta) * prev_timestep).to(dtype=torch.long) + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + alpha_prod_s = self.alphas_cumprod[timestep_s] + beta_prod_s = 1 - alpha_prod_s + + # 3. Compute the predicted noised sample x_s based on the model parameterization + if self.config.prediction_type == "epsilon": # noise-prediction + pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + pred_epsilon = model_output + pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon + elif self.config.prediction_type == "sample": # x-prediction + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon + elif self.config.prediction_type == "v_prediction": # v-prediction + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `TCDScheduler`." + ) + + # 4. Sample and inject noise z ~ N(0, I) for MultiStep Inference + # Noise is not used on the final timestep of the timestep schedule. + # This also means that noise is not used for one-step sampling. + # Eta (referred to as "gamma" in the paper) was introduced to control the stochasticity in every step. + # When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling. + if eta > 0: + if self.step_index != self.num_inference_steps - 1: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=pred_noised_sample.dtype + ) + prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + ( + 1 - alpha_prod_t_prev / alpha_prod_s + ).sqrt() * noise + else: + prev_sample = pred_noised_sample + else: + prev_sample = pred_noised_sample + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample, pred_noised_sample) + + return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep + def previous_timestep(self, timestep): + if self.custom_timesteps: + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + else: + num_inference_steps = ( + self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + ) + prev_t = timestep - self.config.num_train_timesteps // num_inference_steps + + return prev_t diff --git a/diffusers/src/diffusers/schedulers/scheduling_unclip.py b/diffusers/src/diffusers/schedulers/scheduling_unclip.py new file mode 100755 index 0000000..6e15802 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_unclip.py @@ -0,0 +1,352 @@ +# Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UnCLIP +class UnCLIPSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class UnCLIPScheduler(SchedulerMixin, ConfigMixin): + """ + NOTE: do not use this scheduler. The DDPM scheduler has been updated to support the changes made here. This + scheduler will be removed and replaced with DDPM. + + This is a modified DDPM Scheduler specifically for the karlo unCLIP model. + + This scheduler has some minor variations in how it calculates the learned range variance and dynamically + re-calculates betas based off the timesteps it is skipping. + + The scheduler also uses a slightly different step ratio when computing timesteps to use for inference. + + See [`~DDPMScheduler`] for more information on DDPM scheduling + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small_log` + or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between `-clip_sample_range` and `clip_sample_range` for numerical + stability. + clip_sample_range (`float`, default `1.0`): + The range to clip the sample between. See `clip_sample`. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) + or `sample` (directly predicting the noisy sample`) + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + variance_type: str = "fixed_small_log", + clip_sample: bool = True, + clip_sample_range: Optional[float] = 1.0, + prediction_type: str = "epsilon", + beta_schedule: str = "squaredcos_cap_v2", + ): + if beta_schedule != "squaredcos_cap_v2": + raise ValueError("UnCLIPScheduler only supports `beta_schedule`: 'squaredcos_cap_v2'") + + self.betas = betas_for_alpha_bar(num_train_timesteps) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + self.variance_type = variance_type + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.Tensor`: scaled input sample + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Note that this scheduler uses a slightly different step ratio than the other diffusers schedulers. The + different step ratio is to mimic the original karlo implementation and does not affect the quality or accuracy + of the results. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + step_ratio = (self.config.num_train_timesteps - 1) / (self.num_inference_steps - 1) + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_variance(self, t, prev_timestep=None, predicted_variance=None, variance_type=None): + if prev_timestep is None: + prev_timestep = t - 1 + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if prev_timestep == t - 1: + beta = self.betas[t] + else: + beta = 1 - alpha_prod_t / alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = beta_prod_t_prev / beta_prod_t * beta + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small_log": + variance = torch.log(torch.clamp(variance, min=1e-20)) + variance = torch.exp(0.5 * variance) + elif variance_type == "learned_range": + # NOTE difference with DDPM scheduler + min_log = variance.log() + max_log = beta.log() + + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + prev_timestep: Optional[int] = None, + generator=None, + return_dict: bool = True, + ) -> Union[UnCLIPSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + current instance of sample being created by diffusion process. + prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at. + Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than UnCLIPSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + t = timestep + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range": + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + if prev_timestep is None: + prev_timestep = t - 1 + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if prev_timestep == t - 1: + beta = self.betas[t] + alpha = self.alphas[t] + else: + beta = 1 - alpha_prod_t / alpha_prod_t_prev + alpha = 1 - beta + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `sample`" + " for the UnCLIPScheduler." + ) + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = torch.clamp( + pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t + current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + variance_noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device + ) + + variance = self._get_variance( + t, + predicted_variance=predicted_variance, + prev_timestep=prev_timestep, + ) + + if self.variance_type == "fixed_small_log": + variance = variance + elif self.variance_type == "learned_range": + variance = (0.5 * variance).exp() + else: + raise ValueError( + f"variance_type given as {self.variance_type} must be one of `fixed_small_log` or `learned_range`" + " for the UnCLIPScheduler." + ) + + variance = variance * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples diff --git a/diffusers/src/diffusers/schedulers/scheduling_unipc_multistep.py b/diffusers/src/diffusers/schedulers/scheduling_unipc_multistep.py new file mode 100755 index 0000000..995f85c --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -0,0 +1,954 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "epsilon": + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError(" missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError(" missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/src/diffusers/schedulers/scheduling_utils.py b/diffusers/src/diffusers/schedulers/scheduling_utils.py new file mode 100755 index 0000000..f20224b --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_utils.py @@ -0,0 +1,193 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import os +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + +import torch +from huggingface_hub.utils import validate_hf_hub_args + +from ..utils import BaseOutput, PushToHubMixin + + +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + +# NOTE: We make this type an enum because it simplifies usage in docs and prevents +# circular imports when used for `_compatibles` within the schedulers module. +# When it's used as a type in pipelines, it really is a Union because the actual +# scheduler instance is passed in. +class KarrasDiffusionSchedulers(Enum): + DDIMScheduler = 1 + DDPMScheduler = 2 + PNDMScheduler = 3 + LMSDiscreteScheduler = 4 + EulerDiscreteScheduler = 5 + HeunDiscreteScheduler = 6 + EulerAncestralDiscreteScheduler = 7 + DPMSolverMultistepScheduler = 8 + DPMSolverSinglestepScheduler = 9 + KDPM2DiscreteScheduler = 10 + KDPM2AncestralDiscreteScheduler = 11 + DEISMultistepScheduler = 12 + UniPCMultistepScheduler = 13 + DPMSolverSDEScheduler = 14 + EDMEulerScheduler = 15 + + +AysSchedules = { + "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24], + "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0], + "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13], + "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0], + "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0], +} + + +@dataclass +class SchedulerOutput(BaseOutput): + """ + Base class for the output of a scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.Tensor + + +class SchedulerMixin(PushToHubMixin): + """ + Base class for all schedulers. + + [`SchedulerMixin`] contains common functions shared by all schedulers such as general loading and saving + functionalities. + + [`ConfigMixin`] takes care of storing the configuration attributes (like `num_train_timesteps`) that are passed to + the scheduler's `__init__` function, and the attributes can be accessed by `scheduler.config.num_train_timesteps`. + + Class attributes: + - **_compatibles** (`List[str]`) -- A list of scheduler classes that are compatible with the parent scheduler + class. Use [`~ConfigMixin.from_config`] to load a different compatible scheduler class (should be overridden + by parent class). + """ + + config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the scheduler + configuration saved with [`~SchedulerMixin.save_pretrained`]. + subfolder (`str`, *optional*): + The subfolder location of a model file within a larger model repository on the Hub or locally. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. + + + + """ + config, kwargs, commit_hash = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + return_commit_hash=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to a directory so that it can be reloaded using the + [`~SchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes diff --git a/diffusers/src/diffusers/schedulers/scheduling_utils_flax.py b/diffusers/src/diffusers/schedulers/scheduling_utils_flax.py new file mode 100755 index 0000000..ae11baf --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_utils_flax.py @@ -0,0 +1,291 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import math +import os +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from huggingface_hub.utils import validate_hf_hub_args + +from ..utils import BaseOutput, PushToHubMixin + + +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + +# NOTE: We make this type an enum because it simplifies usage in docs and prevents +# circular imports when used for `_compatibles` within the schedulers module. +# When it's used as a type in pipelines, it really is a Union because the actual +# scheduler instance is passed in. +class FlaxKarrasDiffusionSchedulers(Enum): + FlaxDDIMScheduler = 1 + FlaxDDPMScheduler = 2 + FlaxPNDMScheduler = 3 + FlaxLMSDiscreteScheduler = 4 + FlaxDPMSolverMultistepScheduler = 5 + FlaxEulerDiscreteScheduler = 6 + + +@dataclass +class FlaxSchedulerOutput(BaseOutput): + """ + Base class for the scheduler's step function output. + + Args: + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: jnp.ndarray + + +class FlaxSchedulerMixin(PushToHubMixin): + """ + Mixin containing common functions for the schedulers. + + Class attributes: + - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that + `from_config` can be used from a class different than the one used to save the config (should be overridden + by parent class). + """ + + config_name = SCHEDULER_CONFIG_NAME + ignore_for_config = ["dtype"] + _compatibles = [] + has_compatibles = True + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a Scheduler class from a pre-defined JSON-file. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`], + e.g., `./my_model_directory/`. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + config, kwargs = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + **kwargs, + ) + scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs) + + if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False): + state = scheduler.create_state() + + if return_unused_kwargs: + return scheduler, state, unused_kwargs + + return scheduler, state + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~FlaxSchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes + + +def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray: + assert len(shape) >= x.ndim + return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape) + + +def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return jnp.array(betas, dtype=dtype) + + +@flax.struct.dataclass +class CommonSchedulerState: + alphas: jnp.ndarray + betas: jnp.ndarray + alphas_cumprod: jnp.ndarray + + @classmethod + def create(cls, scheduler): + config = scheduler.config + + if config.trained_betas is not None: + betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype) + elif config.beta_schedule == "linear": + betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype) + elif config.beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + betas = ( + jnp.linspace( + config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype + ) + ** 2 + ) + elif config.beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype) + else: + raise NotImplementedError( + f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}" + ) + + alphas = 1.0 - betas + + alphas_cumprod = jnp.cumprod(alphas, axis=0) + + return cls( + alphas=alphas, + betas=betas, + alphas_cumprod=alphas_cumprod, + ) + + +def get_sqrt_alpha_prod( + state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray +): + alphas_cumprod = state.alphas_cumprod + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) + + return sqrt_alpha_prod, sqrt_one_minus_alpha_prod + + +def add_noise_common( + state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray +): + sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps) + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + +def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray): + sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, sample, noise, timesteps) + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/diffusers/src/diffusers/schedulers/scheduling_vq_diffusion.py b/diffusers/src/diffusers/schedulers/scheduling_vq_diffusion.py new file mode 100755 index 0000000..bd8d255 --- /dev/null +++ b/diffusers/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -0,0 +1,467 @@ +# Copyright 2024 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class VQDiffusionSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.LongTensor + + +def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.Tensor: + """ + Convert batch of vector of class indices into batch of log onehot vectors + + Args: + x (`torch.LongTensor` of shape `(batch size, vector length)`): + Batch of class indices + + num_classes (`int`): + number of classes to be used for the onehot vectors + + Returns: + `torch.Tensor` of shape `(batch size, num classes, vector length)`: + Log onehot vectors + """ + x_onehot = F.one_hot(x, num_classes) + x_onehot = x_onehot.permute(0, 2, 1) + log_x = torch.log(x_onehot.float().clamp(min=1e-30)) + return log_x + + +def gumbel_noised(logits: torch.Tensor, generator: Optional[torch.Generator]) -> torch.Tensor: + """ + Apply gumbel noise to `logits` + """ + uniform = torch.rand(logits.shape, device=logits.device, generator=generator) + gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) + noised = gumbel_noise + logits + return noised + + +def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009): + """ + Cumulative and non-cumulative alpha schedules. + + See section 4.1. + """ + att = ( + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start) + + alpha_cum_start + ) + att = np.concatenate(([1], att)) + at = att[1:] / att[:-1] + att = np.concatenate((att[1:], [1])) + return at, att + + +def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999): + """ + Cumulative and non-cumulative gamma schedules. + + See section 4.1. + """ + ctt = ( + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start) + + gamma_cum_start + ) + ctt = np.concatenate(([0], ctt)) + one_minus_ctt = 1 - ctt + one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1] + ct = 1 - one_minus_ct + ctt = np.concatenate((ctt[1:], [0])) + return ct, ctt + + +class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + A scheduler for vector quantized diffusion. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_vec_classes (`int`): + The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked + latent pixel. + num_train_timesteps (`int`, defaults to 100): + The number of diffusion steps to train the model. + alpha_cum_start (`float`, defaults to 0.99999): + The starting cumulative alpha value. + alpha_cum_end (`float`, defaults to 0.00009): + The ending cumulative alpha value. + gamma_cum_start (`float`, defaults to 0.00009): + The starting cumulative gamma value. + gamma_cum_end (`float`, defaults to 0.99999): + The ending cumulative gamma value. + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_vec_classes: int, + num_train_timesteps: int = 100, + alpha_cum_start: float = 0.99999, + alpha_cum_end: float = 0.000009, + gamma_cum_start: float = 0.000009, + gamma_cum_end: float = 0.99999, + ): + self.num_embed = num_vec_classes + + # By convention, the index for the mask class is the last class index + self.mask_class = self.num_embed - 1 + + at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end) + ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end) + + num_non_mask_classes = self.num_embed - 1 + bt = (1 - at - ct) / num_non_mask_classes + btt = (1 - att - ctt) / num_non_mask_classes + + at = torch.tensor(at.astype("float64")) + bt = torch.tensor(bt.astype("float64")) + ct = torch.tensor(ct.astype("float64")) + log_at = torch.log(at) + log_bt = torch.log(bt) + log_ct = torch.log(ct) + + att = torch.tensor(att.astype("float64")) + btt = torch.tensor(btt.astype("float64")) + ctt = torch.tensor(ctt.astype("float64")) + log_cumprod_at = torch.log(att) + log_cumprod_bt = torch.log(btt) + log_cumprod_ct = torch.log(ctt) + + self.log_at = log_at.float() + self.log_bt = log_bt.float() + self.log_ct = log_ct.float() + self.log_cumprod_at = log_cumprod_at.float() + self.log_cumprod_bt = log_cumprod_bt.float() + self.log_cumprod_ct = log_cumprod_ct.float() + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps and diffusion process parameters (alpha, beta, gamma) should be moved + to. + """ + self.num_inference_steps = num_inference_steps + timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.log_at = self.log_at.to(device) + self.log_bt = self.log_bt.to(device) + self.log_ct = self.log_ct.to(device) + self.log_cumprod_at = self.log_cumprod_at.to(device) + self.log_cumprod_bt = self.log_cumprod_bt.to(device) + self.log_cumprod_ct = self.log_cumprod_ct.to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: torch.long, + sample: torch.LongTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[VQDiffusionSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by the reverse transition distribution. See + [`~VQDiffusionScheduler.q_posterior`] for more details about how the distribution is computer. + + Args: + log_p_x_0: (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`): + The log probabilities for the predicted classes of the initial latent pixels. Does not include a + prediction for the masked class as the initial unnoised image cannot be masked. + t (`torch.long`): + The timestep that determines which transition matrices are used. + x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t`. + generator (`torch.Generator`, or `None`): + A random number generator for the noise applied to `p(x_{t-1} | x_t)` before it is sampled from. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] or + `tuple`. + + Returns: + [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + if timestep == 0: + log_p_x_t_min_1 = model_output + else: + log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep) + + log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator) + + x_t_min_1 = log_p_x_t_min_1.argmax(dim=1) + + if not return_dict: + return (x_t_min_1,) + + return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1) + + def q_posterior(self, log_p_x_0, x_t, t): + """ + Calculates the log probabilities for the predicted classes of the image at timestep `t-1`: + + ``` + p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) ) + ``` + + Args: + log_p_x_0 (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`): + The log probabilities for the predicted classes of the initial latent pixels. Does not include a + prediction for the masked class as the initial unnoised image cannot be masked. + x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t`. + t (`torch.Long`): + The timestep that determines which transition matrix is used. + + Returns: + `torch.Tensor` of shape `(batch size, num classes, num latent pixels)`: + The log probabilities for the predicted classes of the image at timestep `t-1`. + """ + log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed) + + log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class( + t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True + ) + + log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class( + t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False + ) + + # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + # . . . + # . . . + # . . . + # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) + q = log_p_x_0 - log_q_x_t_given_x_0 + + # sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... , + # sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) + q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True) + + # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n + # . . . + # . . . + # . . . + # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n + q = q - q_log_sum_exp + + # (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} + # . . . + # . . . + # . . . + # (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} + # c_cumulative_{t-1} ... c_cumulative_{t-1} + q = self.apply_cumulative_transitions(q, t - 1) + + # ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n + # . . . + # . . . + # . . . + # ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n + # c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 + log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp + + # For each column, there are two possible cases. + # + # Where: + # - sum(p_n(x_0))) is summing over all classes for x_0 + # - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's) + # - C_j is the class transitioning to + # + # 1. x_t is masked i.e. x_t = c_k + # + # Simplifying the expression, the column vector is: + # . + # . + # . + # (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0))) + # . + # . + # . + # (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0)) + # + # From equation (11) stated in terms of forward probabilities, the last row is trivially verified. + # + # For the other rows, we can state the equation as ... + # + # (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{k-1} * p(x_0=c_{k-1})] + # + # This verifies the other rows. + # + # 2. x_t is not masked + # + # Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i: + # . + # . + # . + # C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(c_0=C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) + # . + # . + # . + # C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) + # . + # . + # . + # 0 + # + # The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities. + return log_p_x_t_min_1 + + def log_Q_t_transitioning_to_known_class( + self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.Tensor, cumulative: bool + ): + """ + Calculates the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each + latent pixel in `x_t`. + + Args: + t (`torch.Long`): + The timestep that determines which transition matrix is used. + x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t`. + log_onehot_x_t (`torch.Tensor` of shape `(batch size, num classes, num latent pixels)`): + The log one-hot vectors of `x_t`. + cumulative (`bool`): + If cumulative is `False`, the single step transition matrix `t-1`->`t` is used. If cumulative is + `True`, the cumulative transition matrix `0`->`t` is used. + + Returns: + `torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`: + Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability + transition matrix. + + When non cumulative, returns `self.num_classes - 1` rows because the initial latent pixel cannot be + masked. + + Where: + - `q_n` is the probability distribution for the forward process of the `n`th latent pixel. + - C_0 is a class of a latent pixel embedding + - C_k is the class of the masked latent pixel + + non-cumulative result (omitting logarithms): + ``` + q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0) + . . . + . . . + . . . + q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k) + ``` + + cumulative result (omitting logarithms): + ``` + q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0) + . . . + . . . + . . . + q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1}) + ``` + """ + if cumulative: + a = self.log_cumprod_at[t] + b = self.log_cumprod_bt[t] + c = self.log_cumprod_ct[t] + else: + a = self.log_at[t] + b = self.log_bt[t] + c = self.log_ct[t] + + if not cumulative: + # The values in the onehot vector can also be used as the logprobs for transitioning + # from masked latent pixels. If we are not calculating the cumulative transitions, + # we need to save these vectors to be re-appended to the final matrix so the values + # aren't overwritten. + # + # `P(x_t!=mask|x_{t-1=mask}) = 0` and 0 will be the value of the last row of the onehot vector + # if x_t is not masked + # + # `P(x_t=mask|x_{t-1=mask}) = 1` and 1 will be the value of the last row of the onehot vector + # if x_t is masked + log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1) + + # `index_to_log_onehot` will add onehot vectors for masked pixels, + # so the default one hot matrix has one too many rows. See the doc string + # for an explanation of the dimensionality of the returned matrix. + log_onehot_x_t = log_onehot_x_t[:, :-1, :] + + # this is a cheeky trick to produce the transition probabilities using log one-hot vectors. + # + # Don't worry about what values this sets in the columns that mark transitions + # to masked latent pixels. They are overwrote later with the `mask_class_mask`. + # + # Looking at the below logspace formula in non-logspace, each value will evaluate to either + # `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column + # or + # `0 * a + b = b` where `log_Q_t` has the 0 values in the column. + # + # See equation 7 for more details. + log_Q_t = (log_onehot_x_t + a).logaddexp(b) + + # The whole column of each masked pixel is `c` + mask_class_mask = x_t == self.mask_class + mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1) + log_Q_t[mask_class_mask] = c + + if not cumulative: + log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1) + + return log_Q_t + + def apply_cumulative_transitions(self, q, t): + bsz = q.shape[0] + a = self.log_cumprod_at[t] + b = self.log_cumprod_bt[t] + c = self.log_cumprod_ct[t] + + num_latent_pixels = q.shape[2] + c = c.expand(bsz, 1, num_latent_pixels) + + q = (q + a).logaddexp(b) + q = torch.cat((q, c), dim=1) + + return q diff --git a/diffusers/src/diffusers/training_utils.py b/diffusers/src/diffusers/training_utils.py new file mode 100755 index 0000000..26d4a2a --- /dev/null +++ b/diffusers/src/diffusers/training_utils.py @@ -0,0 +1,610 @@ +import contextlib +import copy +import gc +import math +import random +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch + +from .models import UNet2DConditionModel +from .schedulers import SchedulerMixin +from .utils import ( + convert_state_dict_to_diffusers, + convert_state_dict_to_peft, + deprecate, + is_peft_available, + is_torch_npu_available, + is_torchvision_available, + is_transformers_available, +) + + +if is_transformers_available(): + import transformers + +if is_peft_available(): + from peft import set_peft_model_state_dict + +if is_torchvision_available(): + from torchvision import transforms + +if is_torch_npu_available(): + import torch_npu # noqa: F401 + + +def set_seed(seed: int): + """ + Args: + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if is_torch_npu_available(): + torch.npu.manual_seed_all(seed) + else: + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + + +def compute_snr(noise_scheduler, timesteps): + """ + Computes SNR as per + https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + +def resolve_interpolation_mode(interpolation_type: str): + """ + Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The + full list of supported enums is documented at + https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode. + + Args: + interpolation_type (`str`): + A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`, + `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes + in torchvision. + + Returns: + `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize` + transform. + """ + if not is_torchvision_available(): + raise ImportError( + "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function." + ) + + if interpolation_type == "bilinear": + interpolation_mode = transforms.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = transforms.InterpolationMode.BICUBIC + elif interpolation_type == "box": + interpolation_mode = transforms.InterpolationMode.BOX + elif interpolation_type == "nearest": + interpolation_mode = transforms.InterpolationMode.NEAREST + elif interpolation_type == "nearest_exact": + interpolation_mode = transforms.InterpolationMode.NEAREST_EXACT + elif interpolation_type == "hamming": + interpolation_mode = transforms.InterpolationMode.HAMMING + elif interpolation_type == "lanczos": + interpolation_mode = transforms.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ) + + return interpolation_mode + + +def compute_dream_and_update_latents( + unet: UNet2DConditionModel, + noise_scheduler: SchedulerMixin, + timesteps: torch.Tensor, + noise: torch.Tensor, + noisy_latents: torch.Tensor, + target: torch.Tensor, + encoder_hidden_states: torch.Tensor, + dream_detail_preservation: float = 1.0, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210. + DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra + forward step without gradients. + + Args: + `unet`: The state unet to use to make a prediction. + `noise_scheduler`: The noise scheduler used to add noise for the given timestep. + `timesteps`: The timesteps for the noise_scheduler to user. + `noise`: A tensor of noise in the shape of noisy_latents. + `noisy_latents`: Previously noise latents from the training loop. + `target`: The ground-truth tensor to predict after eps is removed. + `encoder_hidden_states`: Text embeddings from the text model. + `dream_detail_preservation`: A float value that indicates detail preservation level. + See reference. + + Returns: + `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target. + """ + alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None] + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments. + dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation + + pred = None + with torch.no_grad(): + pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + _noisy_latents, _target = (None, None) + if noise_scheduler.config.prediction_type == "epsilon": + predicted_noise = pred + delta_noise = (noise - predicted_noise).detach() + delta_noise.mul_(dream_lambda) + _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise) + _target = target.add(delta_noise) + elif noise_scheduler.config.prediction_type == "v_prediction": + raise NotImplementedError("DREAM has not been implemented for v-prediction") + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + return _noisy_latents, _target + + +def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: + r""" + Returns: + A state dict containing just the LoRA parameters. + """ + lora_state_dict = {} + + for name, module in unet.named_modules(): + if hasattr(module, "set_lora_layer"): + lora_layer = getattr(module, "lora_layer") + if lora_layer is not None: + current_lora_layer_sd = lora_layer.state_dict() + for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items(): + # The matrix name can either be "down" or "up". + lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param + + return lora_state_dict + + +def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32): + if not isinstance(model, list): + model = [model] + for m in model: + for param in m.parameters(): + # only upcast trainable parameters into fp32 + if param.requires_grad: + param.data = param.to(dtype) + + +def _set_state_dict_into_text_encoder( + lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module +): + """ + Sets the `lora_state_dict` into `text_encoder` coming from `transformers`. + + Args: + lora_state_dict: The state dictionary to be set. + prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`. + text_encoder: Where the `lora_state_dict` is to be set. + """ + + text_encoder_state_dict = { + f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix) + } + text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict)) + set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default") + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def clear_objs_and_retain_memory(objs: List[Any]): + """Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator.""" + if len(objs) >= 1: + for obj in objs: + del obj + + gc.collect() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.backends.mps.is_available(): + torch.mps.empty_cache() + elif is_torch_npu_available(): + torch_npu.empty_cache() + + +# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 +class EMAModel: + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + use_ema_warmup: bool = False, + inv_gamma: Union[float, int] = 1.0, + power: Union[float, int] = 2 / 3, + foreach: bool = False, + model_cls: Optional[Any] = None, + model_config: Dict[str, Any] = None, + **kwargs, + ): + """ + Args: + parameters (Iterable[torch.nn.Parameter]): The parameters to track. + decay (float): The decay factor for the exponential moving average. + min_decay (float): The minimum decay factor for the exponential moving average. + update_after_step (int): The number of steps to wait before starting to update the EMA weights. + use_ema_warmup (bool): Whether to use EMA warmup. + inv_gamma (float): + Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. + power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster. + device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA + weights will be stored on CPU. + + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + """ + + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + + # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility + use_ema_warmup = True + + if kwargs.get("max_value", None) is not None: + deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." + deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) + decay = kwargs["max_value"] + + if kwargs.get("min_value", None) is not None: + deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." + deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) + min_decay = kwargs["min_value"] + + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + + if kwargs.get("device", None) is not None: + deprecation_message = "The `device` argument is deprecated. Please use `to` instead." + deprecate("device", "1.0.0", deprecation_message, standard_warn=False) + self.to(device=kwargs["device"]) + + self.temp_stored_params = None + + self.decay = decay + self.min_decay = min_decay + self.update_after_step = update_after_step + self.use_ema_warmup = use_ema_warmup + self.inv_gamma = inv_gamma + self.power = power + self.optimization_step = 0 + self.cur_decay_value = None # set in `step()` + self.foreach = foreach + + self.model_cls = model_cls + self.model_config = model_config + + @classmethod + def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel": + _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) + model = model_cls.from_pretrained(path) + + ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach) + + ema_model.load_state_dict(ema_kwargs) + return ema_model + + def save_pretrained(self, path): + if self.model_cls is None: + raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") + + if self.model_config is None: + raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") + + model = self.model_cls.from_config(self.model_config) + state_dict = self.state_dict() + state_dict.pop("shadow_params", None) + + model.register_to_config(**state_dict) + self.copy_to(model.parameters()) + model.save_pretrained(path) + + def get_decay(self, optimization_step: int) -> float: + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + + if step <= 0: + return 0.0 + + if self.use_ema_warmup: + cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power + else: + cur_decay_value = (1 + step) / (10 + step) + + cur_decay_value = min(cur_decay_value, self.decay) + # make sure decay is not smaller than min_decay + cur_decay_value = max(cur_decay_value, self.min_decay) + return cur_decay_value + + @torch.no_grad() + def step(self, parameters: Iterable[torch.nn.Parameter]): + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + + parameters = list(parameters) + + self.optimization_step += 1 + + # Compute the decay factor for the exponential moving average. + decay = self.get_decay(self.optimization_step) + self.cur_decay_value = decay + one_minus_decay = 1 - decay + + context_manager = contextlib.nullcontext + if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled(): + import deepspeed + + if self.foreach: + if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled(): + context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None) + + with context_manager(): + params_grad = [param for param in parameters if param.requires_grad] + s_params_grad = [ + s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad + ] + + if len(params_grad) < len(parameters): + torch._foreach_copy_( + [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad], + [param for param in parameters if not param.requires_grad], + non_blocking=True, + ) + + torch._foreach_sub_( + s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay + ) + + else: + for s_param, param in zip(self.shadow_params, parameters): + if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled(): + context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) + + with context_manager(): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = list(parameters) + if self.foreach: + torch._foreach_copy_( + [param.data for param in parameters], + [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)], + ) + else: + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.to(param.device).data) + + def pin_memory(self) -> None: + r""" + Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for + offloading EMA params to the host. + """ + + self.shadow_params = [p.pin_memory() for p in self.shadow_params] + + def to(self, device=None, dtype=None, non_blocking=False) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype, non_blocking=non_blocking) + if p.is_floating_point() + else p.to(device=device, non_blocking=non_blocking) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during + checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "min_decay": self.min_decay, + "optimization_step": self.optimization_step, + "update_after_step": self.update_after_step, + "use_ema_warmup": self.use_ema_warmup, + "inv_gamma": self.inv_gamma, + "power": self.power, + "shadow_params": self.shadow_params, + } + + def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Save the current parameters for restoring later. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] + + def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: + affecting the original optimization process. Store the parameters before the `copy_to()` method. After + validation (or model saving), use this to restore the former parameters. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + if self.temp_stored_params is None: + raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") + if self.foreach: + torch._foreach_copy_( + [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params] + ) + else: + for c_param, param in zip(self.temp_stored_params, parameters): + param.data.copy_(c_param.data) + + # Better memory-wise. + self.temp_stored_params = None + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Args: + Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the + ema state dict. + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict.get("decay", self.decay) + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.min_decay = state_dict.get("min_decay", self.min_decay) + if not isinstance(self.min_decay, float): + raise ValueError("Invalid min_decay") + + self.optimization_step = state_dict.get("optimization_step", self.optimization_step) + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.update_after_step = state_dict.get("update_after_step", self.update_after_step) + if not isinstance(self.update_after_step, int): + raise ValueError("Invalid update_after_step") + + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) + if not isinstance(self.use_ema_warmup, bool): + raise ValueError("Invalid use_ema_warmup") + + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) + if not isinstance(self.inv_gamma, (float, int)): + raise ValueError("Invalid inv_gamma") + + self.power = state_dict.get("power", self.power) + if not isinstance(self.power, (float, int)): + raise ValueError("Invalid power") + + shadow_params = state_dict.get("shadow_params", None) + if shadow_params is not None: + self.shadow_params = shadow_params + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") diff --git a/diffusers/src/diffusers/utils/__init__.py b/diffusers/src/diffusers/utils/__init__.py new file mode 100755 index 0000000..c7ea2bc --- /dev/null +++ b/diffusers/src/diffusers/utils/__init__.py @@ -0,0 +1,134 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +from packaging import version + +from .. import __version__ +from .constants import ( + CONFIG_NAME, + DEPRECATED_REVISION_ARGS, + DIFFUSERS_DYNAMIC_MODULE_NAME, + FLAX_WEIGHTS_NAME, + HF_MODULES_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + MIN_PEFT_VERSION, + ONNX_EXTERNAL_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFETENSORS_FILE_EXTENSION, + SAFETENSORS_WEIGHTS_NAME, + USE_PEFT_BACKEND, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, +) +from .deprecation_utils import deprecate +from .doc_utils import replace_example_docstring +from .dynamic_modules_utils import get_class_from_dynamic_module +from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video +from .hub_utils import ( + PushToHubMixin, + _add_variant, + _get_checkpoint_shard_files, + _get_model_file, + extract_commit_hash, + http_user_agent, +) +from .import_utils import ( + BACKENDS_MAPPING, + DIFFUSERS_SLOW_IMPORT, + ENV_VARS_TRUE_AND_AUTO_VALUES, + ENV_VARS_TRUE_VALUES, + USE_JAX, + USE_TF, + USE_TORCH, + DummyObject, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_accelerate_available, + is_accelerate_version, + is_bitsandbytes_available, + is_bs4_available, + is_flax_available, + is_ftfy_available, + is_google_colab, + is_inflect_available, + is_invisible_watermark_available, + is_k_diffusion_available, + is_k_diffusion_version, + is_librosa_available, + is_matplotlib_available, + is_note_seq_available, + is_onnx_available, + is_peft_available, + is_peft_version, + is_safetensors_available, + is_scipy_available, + is_sentencepiece_available, + is_tensorboard_available, + is_timm_available, + is_torch_available, + is_torch_npu_available, + is_torch_version, + is_torch_xla_available, + is_torchsde_available, + is_torchvision_available, + is_transformers_available, + is_transformers_version, + is_unidecode_available, + is_wandb_available, + is_xformers_available, + requires_backends, +) +from .loading_utils import load_image, load_video +from .logging import get_logger +from .outputs import BaseOutput +from .peft_utils import ( + check_peft_version, + delete_adapter_layers, + get_adapter_name, + get_peft_kwargs, + recurse_remove_peft_layers, + scale_lora_layers, + set_adapter_layers, + set_weights_and_activate_adapters, + unscale_lora_layers, +) +from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil +from .state_dict_utils import ( + convert_all_state_dict_to_peft, + convert_state_dict_to_diffusers, + convert_state_dict_to_kohya, + convert_state_dict_to_peft, + convert_unet_state_dict_to_peft, +) + + +logger = get_logger(__name__) + + +def check_min_version(min_version): + if version.parse(__version__) < version.parse(min_version): + if "dev" in min_version: + error_message = ( + "This example requires a source install from HuggingFace diffusers (see " + "`https://huggingface.co/docs/diffusers/installation#install-from-source`)," + ) + else: + error_message = f"This example requires a minimum version of {min_version}," + error_message += f" but the version found is {__version__}.\n" + raise ImportError(error_message) diff --git a/diffusers/src/diffusers/utils/accelerate_utils.py b/diffusers/src/diffusers/utils/accelerate_utils.py new file mode 100755 index 0000000..99a8b3a --- /dev/null +++ b/diffusers/src/diffusers/utils/accelerate_utils.py @@ -0,0 +1,48 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Accelerate utilities: Utilities related to accelerate +""" + +from packaging import version + +from .import_utils import is_accelerate_available + + +if is_accelerate_available(): + import accelerate + + +def apply_forward_hook(method): + """ + Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful + for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the + appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`]. + + This decorator looks inside the internal `_hf_hook` property to find a registered offload hook. + + :param method: The method to decorate. This method should be a method of a PyTorch module. + """ + if not is_accelerate_available(): + return method + accelerate_version = version.parse(accelerate.__version__).base_version + if version.parse(accelerate_version) < version.parse("0.17.0"): + return method + + def wrapper(self, *args, **kwargs): + if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"): + self._hf_hook.pre_forward(self) + return method(self, *args, **kwargs) + + return wrapper diff --git a/diffusers/src/diffusers/utils/constants.py b/diffusers/src/diffusers/utils/constants.py new file mode 100755 index 0000000..553ac5d --- /dev/null +++ b/diffusers/src/diffusers/utils/constants.py @@ -0,0 +1,57 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import os + +from huggingface_hub.constants import HF_HOME +from packaging import version + +from ..dependency_versions_check import dep_version_check +from .import_utils import ENV_VARS_TRUE_VALUES, is_peft_available, is_transformers_available + + +MIN_PEFT_VERSION = "0.6.0" +MIN_TRANSFORMERS_VERSION = "4.34.0" +_CHECK_PEFT = os.environ.get("_CHECK_PEFT", "1") in ENV_VARS_TRUE_VALUES + + +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "diffusion_pytorch_model.bin" +WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.bin.index.json" +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" +ONNX_WEIGHTS_NAME = "model.onnx" +SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" +SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json" +SAFETENSORS_FILE_EXTENSION = "safetensors" +ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" +HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") +DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" +HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) +DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] + +# Below should be `True` if the current version of `peft` and `transformers` are compatible with +# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are +# available. +# For PEFT it is has to be greater than or equal to 0.6.0 and for transformers it has to be greater than or equal to 4.34.0. +_required_peft_version = is_peft_available() and version.parse( + version.parse(importlib.metadata.version("peft")).base_version +) >= version.parse(MIN_PEFT_VERSION) +_required_transformers_version = is_transformers_available() and version.parse( + version.parse(importlib.metadata.version("transformers")).base_version +) >= version.parse(MIN_TRANSFORMERS_VERSION) + +USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version + +if USE_PEFT_BACKEND and _CHECK_PEFT: + dep_version_check("peft") diff --git a/diffusers/src/diffusers/utils/deprecation_utils.py b/diffusers/src/diffusers/utils/deprecation_utils.py new file mode 100755 index 0000000..f482ded --- /dev/null +++ b/diffusers/src/diffusers/utils/deprecation_utils.py @@ -0,0 +1,49 @@ +import inspect +import warnings +from typing import Any, Dict, Optional, Union + +from packaging import version + + +def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2): + from .. import __version__ + + deprecated_kwargs = take_from + values = () + if not isinstance(args[0], tuple): + args = (args,) + + for attribute, version_name, message in args: + if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): + raise ValueError( + f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" + f" version {__version__} is >= {version_name}" + ) + + warning = None + if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: + values += (deprecated_kwargs.pop(attribute),) + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." + elif hasattr(deprecated_kwargs, attribute): + values += (getattr(deprecated_kwargs, attribute),) + warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." + elif deprecated_kwargs is None: + warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." + + if warning is not None: + warning = warning + " " if standard_warn else "" + warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel) + + if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: + call_frame = inspect.getouterframes(inspect.currentframe())[1] + filename = call_frame.filename + line_number = call_frame.lineno + function = call_frame.function + key, value = next(iter(deprecated_kwargs.items())) + raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") + + if len(values) == 0: + return + elif len(values) == 1: + return values[0] + return values diff --git a/diffusers/src/diffusers/utils/doc_utils.py b/diffusers/src/diffusers/utils/doc_utils.py new file mode 100755 index 0000000..fe633e6 --- /dev/null +++ b/diffusers/src/diffusers/utils/doc_utils.py @@ -0,0 +1,39 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Doc utilities: Utilities related to documentation +""" + +import re + + +def replace_example_docstring(example_docstring): + def docstring_decorator(fn): + func_doc = fn.__doc__ + lines = func_doc.split("\n") + i = 0 + while i < len(lines) and re.search(r"^\s*Examples?:\s*$", lines[i]) is None: + i += 1 + if i < len(lines): + lines[i] = example_docstring + func_doc = "\n".join(lines) + else: + raise ValueError( + f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, " + f"current docstring is:\n{func_doc}" + ) + fn.__doc__ = func_doc + return fn + + return docstring_decorator diff --git a/diffusers/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/diffusers/src/diffusers/utils/dummy_flax_and_transformers_objects.py new file mode 100755 index 0000000..5e65e53 --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_flax_and_transformers_objects.py @@ -0,0 +1,77 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + +class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + +class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + +class FlaxStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + +class FlaxStableDiffusionXLPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) diff --git a/diffusers/src/diffusers/utils/dummy_flax_objects.py b/diffusers/src/diffusers/utils/dummy_flax_objects.py new file mode 100755 index 0000000..5fa8dbc --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_flax_objects.py @@ -0,0 +1,212 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class FlaxControlNetModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxModelMixin(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxUNet2DConditionModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxAutoencoderKL(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxDiffusionPipeline(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxDDIMScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxDDPMScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxEulerDiscreteScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxKarrasVeScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxLMSDiscreteScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxPNDMScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxSchedulerMixin(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxScoreSdeVeScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) diff --git a/diffusers/src/diffusers/utils/dummy_note_seq_objects.py b/diffusers/src/diffusers/utils/dummy_note_seq_objects.py new file mode 100755 index 0000000..c02d0b0 --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_note_seq_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class MidiProcessor(metaclass=DummyObject): + _backends = ["note_seq"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["note_seq"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["note_seq"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["note_seq"]) diff --git a/diffusers/src/diffusers/utils/dummy_onnx_objects.py b/diffusers/src/diffusers/utils/dummy_onnx_objects.py new file mode 100755 index 0000000..bde5f6a --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_onnx_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class OnnxRuntimeModel(metaclass=DummyObject): + _backends = ["onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["onnx"]) diff --git a/diffusers/src/diffusers/utils/dummy_pt_objects.py b/diffusers/src/diffusers/utils/dummy_pt_objects.py new file mode 100755 index 0000000..1ab946c --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_pt_objects.py @@ -0,0 +1,1545 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class AsymmetricAutoencoderKL(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AuraFlowTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoencoderKL(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoencoderKLCogVideoX(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoencoderKLTemporalDecoder(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoencoderOobleck(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoencoderTiny(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class CogVideoXTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ConsistencyDecoderVAE(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ControlNetXSAdapter(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DiTTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class FluxControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class FluxMultiControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class FluxTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class HunyuanDiT2DControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class HunyuanDiT2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class I2VGenXLUNet(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class Kandinsky3UNet(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LatteTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LuminaNextDiT2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ModelMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class MotionAdapter(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class MultiAdapter(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class PixArtTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class PriorTransformer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SD3ControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SD3MultiControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SD3Transformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SparseControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class StableAudioDiTModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class T2IAdapter(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class T5FilmDecoder(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class Transformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNet1DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNet2DConditionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNet2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNet3DConditionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNetControlNetXSModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNetMotionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNetSpatioTemporalConditionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UVit2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class VQModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +def get_constant_schedule(*args, **kwargs): + requires_backends(get_constant_schedule, ["torch"]) + + +def get_constant_schedule_with_warmup(*args, **kwargs): + requires_backends(get_constant_schedule_with_warmup, ["torch"]) + + +def get_cosine_schedule_with_warmup(*args, **kwargs): + requires_backends(get_cosine_schedule_with_warmup, ["torch"]) + + +def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): + requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"]) + + +def get_linear_schedule_with_warmup(*args, **kwargs): + requires_backends(get_linear_schedule_with_warmup, ["torch"]) + + +def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): + requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"]) + + +def get_scheduler(*args, **kwargs): + requires_backends(get_scheduler, ["torch"]) + + +class AudioPipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoPipelineForImage2Image(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoPipelineForInpainting(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoPipelineForText2Image(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BlipDiffusionControlNetPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BlipDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class CLIPImageProjection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ConsistencyModelPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DanceDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDIMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDPMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DiTPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ImagePipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class KarrasVePipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LDMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LDMSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class PNDMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class RePaintPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ScoreSdeVePipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class StableDiffusionMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AmusedScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class CMStochasticIterativeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class CogVideoXDDIMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class CogVideoXDPMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDIMInverseScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDIMParallelScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDIMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDPMParallelScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDPMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDPMWuerstchenScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DEISMultistepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DPMSolverMultistepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DPMSolverSinglestepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EDMDPMSolverMultistepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EDMEulerScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EulerAncestralDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EulerDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class FlowMatchEulerDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class FlowMatchHeunDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class HeunDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class IPNDMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class KarrasVeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class KDPM2AncestralDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class KDPM2DiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LCMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class PNDMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class RePaintScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SASolverScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SchedulerMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ScoreSdeVeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class TCDScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UnCLIPScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UniPCMultistepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class VQDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EMAModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) diff --git a/diffusers/src/diffusers/utils/dummy_torch_and_librosa_objects.py b/diffusers/src/diffusers/utils/dummy_torch_and_librosa_objects.py new file mode 100755 index 0000000..2088bc4 --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_torch_and_librosa_objects.py @@ -0,0 +1,32 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class AudioDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "librosa"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "librosa"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "librosa"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "librosa"]) + + +class Mel(metaclass=DummyObject): + _backends = ["torch", "librosa"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "librosa"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "librosa"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "librosa"]) diff --git a/diffusers/src/diffusers/utils/dummy_torch_and_scipy_objects.py b/diffusers/src/diffusers/utils/dummy_torch_and_scipy_objects.py new file mode 100755 index 0000000..a1ff258 --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_torch_and_scipy_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class LMSDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch", "scipy"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "scipy"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "scipy"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "scipy"]) diff --git a/diffusers/src/diffusers/utils/dummy_torch_and_torchsde_objects.py b/diffusers/src/diffusers/utils/dummy_torch_and_torchsde_objects.py new file mode 100755 index 0000000..6ff1423 --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_torch_and_torchsde_objects.py @@ -0,0 +1,32 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class CosineDPMSolverMultistepScheduler(metaclass=DummyObject): + _backends = ["torch", "torchsde"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "torchsde"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) + + +class DPMSolverSDEScheduler(metaclass=DummyObject): + _backends = ["torch", "torchsde"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "torchsde"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) diff --git a/diffusers/src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py b/diffusers/src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py new file mode 100755 index 0000000..2ab00c5 --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py @@ -0,0 +1,32 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class StableDiffusionKDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "k_diffusion"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "k_diffusion"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "k_diffusion"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "k_diffusion"]) + + +class StableDiffusionXLKDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "k_diffusion"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "k_diffusion"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "k_diffusion"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "k_diffusion"]) diff --git a/diffusers/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py b/diffusers/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py new file mode 100755 index 0000000..b7afad8 --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py @@ -0,0 +1,92 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class OnnxStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class StableDiffusionOnnxPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) diff --git a/diffusers/src/diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py b/diffusers/src/diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py new file mode 100755 index 0000000..349db10 --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py @@ -0,0 +1,47 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class KolorsImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "sentencepiece"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "sentencepiece"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "sentencepiece"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "sentencepiece"]) + + +class KolorsPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "sentencepiece"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "sentencepiece"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "sentencepiece"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "sentencepiece"]) + + +class KolorsPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "sentencepiece"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "sentencepiece"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "sentencepiece"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "sentencepiece"]) diff --git a/diffusers/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/diffusers/src/diffusers/utils/dummy_torch_and_transformers_objects.py new file mode 100755 index 0000000..4f22501 --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -0,0 +1,2207 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AltDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AmusedImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AmusedInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AmusedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AnimateDiffControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AnimateDiffPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AnimateDiffPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AnimateDiffSDXLPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AnimateDiffSparseControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AnimateDiffVideoToVideoControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AudioLDM2Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AudioLDM2ProjectionModel(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AudioLDM2UNet2DConditionModel(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AudioLDMPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AuraFlowPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class CLIPImageProjection(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class CogVideoXImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class CogVideoXPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class CogVideoXVideoToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class CycleDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class FluxControlNetImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class FluxControlNetInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class FluxControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class FluxImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class FluxInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class FluxPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class HunyuanDiTControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class HunyuanDiTPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class HunyuanDiTPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class I2VGenXLPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFInpaintingPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ImageTextPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Kandinsky3Img2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Kandinsky3Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyCombinedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyImg2ImgCombinedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyInpaintCombinedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyPriorPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22CombinedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22ControlnetImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22ControlnetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22Img2ImgCombinedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22Img2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22InpaintCombinedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22InpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22PriorEmb2EmbPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22PriorPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LatentConsistencyModelImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LatentConsistencyModelPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LattePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LDMTextToImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LEditsPPPipelineStableDiffusion(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LuminaText2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class MarigoldDepthPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class MarigoldNormalsPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class MusicLDMPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class PaintByExamplePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class PIAPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class PixArtAlphaPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class PixArtSigmaPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class PixArtSigmaPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class SemanticStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ShapEImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ShapEPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableAudioPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableAudioProjectionModel(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableCascadeCombinedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableCascadeDecoderPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableCascadePriorPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusion3ControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusion3Img2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusion3InpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusion3PAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusion3Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionAdapterPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionControlNetImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionControlNetPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionControlNetXSPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionDiffEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionGLIGENPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionGLIGENTextImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionInstructPix2PixPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionLatentUpscalePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionLDM3DPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionModelEditingPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionPanoramaPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionParadigmsPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionPipelineSafe(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionSAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionUpscalePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLAdapterPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLControlNetImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLControlNetInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLControlNetPAGImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLControlNetPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLPAGImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLPAGInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableUnCLIPPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableVideoDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class TextToVideoSDPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class TextToVideoZeroPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class TextToVideoZeroSDXLPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UnCLIPImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UnCLIPPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UniDiffuserModel(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UniDiffuserPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UniDiffuserTextDecoder(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VideoToVideoSDPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VQDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class WuerstchenCombinedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class WuerstchenDecoderPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class WuerstchenPriorPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) diff --git a/diffusers/src/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py b/diffusers/src/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py new file mode 100755 index 0000000..fbde04e --- /dev/null +++ b/diffusers/src/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class SpectrogramDiffusionPipeline(metaclass=DummyObject): + _backends = ["transformers", "torch", "note_seq"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers", "torch", "note_seq"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["transformers", "torch", "note_seq"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["transformers", "torch", "note_seq"]) diff --git a/diffusers/src/diffusers/utils/dynamic_modules_utils.py b/diffusers/src/diffusers/utils/dynamic_modules_utils.py new file mode 100755 index 0000000..f0cf953 --- /dev/null +++ b/diffusers/src/diffusers/utils/dynamic_modules_utils.py @@ -0,0 +1,457 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities to dynamically load objects from the Hub.""" + +import importlib +import inspect +import json +import os +import re +import shutil +import sys +from pathlib import Path +from typing import Dict, Optional, Union +from urllib import request + +from huggingface_hub import hf_hub_download, model_info +from huggingface_hub.utils import RevisionNotFoundError, validate_hf_hub_args +from packaging import version + +from .. import __version__ +from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror +COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror" + + +def get_diffusers_versions(): + url = "https://pypi.org/pypi/diffusers/json" + releases = json.loads(request.urlopen(url).read())["releases"].keys() + return sorted(releases, key=lambda x: version.Version(x)) + + +def init_hf_modules(): + """ + Creates the cache directory for modules with an init, and adds it to the Python path. + """ + # This function has already been executed if HF_MODULES_CACHE already is in the Python path. + if HF_MODULES_CACHE in sys.path: + return + + sys.path.append(HF_MODULES_CACHE) + os.makedirs(HF_MODULES_CACHE, exist_ok=True) + init_path = Path(HF_MODULES_CACHE) / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def create_dynamic_module(name: Union[str, os.PathLike]): + """ + Creates a dynamic module in the cache directory for modules. + """ + init_hf_modules() + dynamic_module_path = Path(HF_MODULES_CACHE) / name + # If the parent module does not exist yet, recursively create it. + if not dynamic_module_path.parent.exists(): + create_dynamic_module(dynamic_module_path.parent) + os.makedirs(dynamic_module_path, exist_ok=True) + init_path = dynamic_module_path / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def get_relative_imports(module_file): + """ + Get the list of modules that are relatively imported in a module file. + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + with open(module_file, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import .xxx` + relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from .xxx import yyy` + relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) + # Unique-ify + return list(set(relative_imports)) + + +def get_relative_import_files(module_file): + """ + Get the list of all files that are needed for a given module. Note that this function recurses through the relative + imports (if a imports b and b imports c, it will return module files for b and c). + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + no_change = False + files_to_check = [module_file] + all_relative_imports = [] + + # Let's recurse through all relative imports + while not no_change: + new_imports = [] + for f in files_to_check: + new_imports.extend(get_relative_imports(f)) + + module_path = Path(module_file).parent + new_import_files = [str(module_path / m) for m in new_imports] + new_import_files = [f for f in new_import_files if f not in all_relative_imports] + files_to_check = [f"{f}.py" for f in new_import_files] + + no_change = len(new_import_files) == 0 + all_relative_imports.extend(files_to_check) + + return all_relative_imports + + +def check_imports(filename): + """ + Check if the current Python environment contains all the libraries that are imported in a file. + """ + with open(filename, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import xxx` + imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from xxx import yyy` + imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) + # Only keep the top-level module + imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] + + # Unique-ify and test we got them all + imports = list(set(imports)) + missing_packages = [] + for imp in imports: + try: + importlib.import_module(imp) + except ImportError: + missing_packages.append(imp) + + if len(missing_packages) > 0: + raise ImportError( + "This modeling file requires the following packages that were not found in your environment: " + f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" + ) + + return get_relative_imports(filename) + + +def get_class_in_module(class_name, module_path): + """ + Import a module on the cache directory for modules and extract a class from it. + """ + module_path = module_path.replace(os.path.sep, ".") + module = importlib.import_module(module_path) + + if class_name is None: + return find_pipeline_class(module) + return getattr(module, class_name) + + +def find_pipeline_class(loaded_module): + """ + Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class + inheriting from `DiffusionPipeline`. + """ + from ..pipelines import DiffusionPipeline + + cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass)) + + pipeline_class = None + for cls_name, cls in cls_members.items(): + if ( + cls_name != DiffusionPipeline.__name__ + and issubclass(cls, DiffusionPipeline) + and cls.__module__.split(".")[0] != "diffusers" + ): + if pipeline_class is not None: + raise ValueError( + f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:" + f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in" + f" {loaded_module}." + ) + pipeline_class = cls + + return pipeline_class + + +@validate_hf_hub_args +def get_cached_module_file( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, +): + """ + Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached + Transformers module. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or + [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + Returns: + `str`: The path to the module inside the cache. + """ + # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) + + if os.path.isfile(module_file_or_url): + resolved_module_file = module_file_or_url + submodule = "local" + elif pretrained_model_name_or_path.count("/") == 0: + available_versions = get_diffusers_versions() + # cut ".dev0" + latest_version = "v" + ".".join(__version__.split(".")[:3]) + + # retrieve github version that matches + if revision is None: + revision = latest_version if latest_version[1:] in available_versions else "main" + logger.info(f"Defaulting to latest_version: {revision}.") + elif revision in available_versions: + revision = f"v{revision}" + elif revision == "main": + revision = revision + else: + raise ValueError( + f"`custom_revision`: {revision} does not exist. Please make sure to choose one of" + f" {', '.join(available_versions + ['main'])}." + ) + + try: + resolved_module_file = hf_hub_download( + repo_id=COMMUNITY_PIPELINES_MIRROR_ID, + repo_type="dataset", + filename=f"{revision}/{pretrained_model_name_or_path}.py", + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + ) + submodule = "git" + module_file = pretrained_model_name_or_path + ".py" + except RevisionNotFoundError as e: + raise EnvironmentError( + f"Revision '{revision}' not found in the community pipelines mirror. Check available revisions on" + " https://huggingface.co/datasets/diffusers/community-pipelines-mirror/tree/main." + " If you don't find the revision you are looking for, please open an issue on https://github.com/huggingface/diffusers/issues." + ) from e + except EnvironmentError: + logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") + raise + else: + try: + # Load from URL or cache if already cached + resolved_module_file = hf_hub_download( + pretrained_model_name_or_path, + module_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + ) + submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/"))) + except EnvironmentError: + logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") + raise + + # Check we have all the requirements in our environment + modules_needed = check_imports(resolved_module_file) + + # Now we move the module inside our cached dynamic modules. + full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule + create_dynamic_module(full_submodule) + submodule_path = Path(HF_MODULES_CACHE) / full_submodule + if submodule == "local" or submodule == "git": + # We always copy local files (we could hash the file to see if there was a change, and give them the name of + # that hash, to only copy when there is a modification but it seems overkill for now). + # The only reason we do the copy is to avoid putting too many folders in sys.path. + shutil.copy(resolved_module_file, submodule_path / module_file) + for module_needed in modules_needed: + if len(module_needed.split(".")) == 2: + module_needed = "/".join(module_needed.split(".")) + module_folder = module_needed.split("/")[0] + if not os.path.exists(submodule_path / module_folder): + os.makedirs(submodule_path / module_folder) + module_needed = f"{module_needed}.py" + shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) + else: + # Get the commit hash + # TODO: we will get this info in the etag soon, so retrieve it from there and not here. + commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha + + # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the + # benefit of versioning. + submodule_path = submodule_path / commit_hash + full_submodule = full_submodule + os.path.sep + commit_hash + create_dynamic_module(full_submodule) + + if not (submodule_path / module_file).exists(): + if len(module_file.split("/")) == 2: + module_folder = module_file.split("/")[0] + if not os.path.exists(submodule_path / module_folder): + os.makedirs(submodule_path / module_folder) + shutil.copy(resolved_module_file, submodule_path / module_file) + + # Make sure we also have every file with relative + for module_needed in modules_needed: + if len(module_needed.split(".")) == 2: + module_needed = "/".join(module_needed.split(".")) + if not (submodule_path / module_needed).exists(): + get_cached_module_file( + pretrained_model_name_or_path, + f"{module_needed}.py", + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + ) + return os.path.join(full_submodule, module_file) + + +@validate_hf_hub_args +def get_class_from_dynamic_module( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + class_name: Optional[str] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + Extracts a class from a module file, present in the local folder or repository of a model. + + + + Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should + therefore only be called on trusted repos. + + + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + class_name (`str`): + The name of the class to import in the module. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or + [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + Returns: + `type`: The class, dynamically imported from the module. + + Examples: + + ```python + # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this + # module. + cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") + ```""" + # And lastly we get the class inside our newly created module + final_module = get_cached_module_file( + pretrained_model_name_or_path, + module_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + ) + return get_class_in_module(class_name, final_module.replace(".py", "")) diff --git a/diffusers/src/diffusers/utils/export_utils.py b/diffusers/src/diffusers/utils/export_utils.py new file mode 100755 index 0000000..0080543 --- /dev/null +++ b/diffusers/src/diffusers/utils/export_utils.py @@ -0,0 +1,184 @@ +import io +import random +import struct +import tempfile +from contextlib import contextmanager +from typing import List, Union + +import numpy as np +import PIL.Image +import PIL.ImageOps + +from .import_utils import BACKENDS_MAPPING, is_imageio_available, is_opencv_available +from .logging import get_logger + + +global_rng = random.Random() + +logger = get_logger(__name__) + + +@contextmanager +def buffered_writer(raw_f): + f = io.BufferedWriter(raw_f) + yield f + f.flush() + + +def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None, fps: int = 10) -> str: + if output_gif_path is None: + output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name + + image[0].save( + output_gif_path, + save_all=True, + append_images=image[1:], + optimize=False, + duration=1000 // fps, + loop=0, + ) + return output_gif_path + + +def export_to_ply(mesh, output_ply_path: str = None): + """ + Write a PLY file for a mesh. + """ + if output_ply_path is None: + output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name + + coords = mesh.verts.detach().cpu().numpy() + faces = mesh.faces.cpu().numpy() + rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1) + + with buffered_writer(open(output_ply_path, "wb")) as f: + f.write(b"ply\n") + f.write(b"format binary_little_endian 1.0\n") + f.write(bytes(f"element vertex {len(coords)}\n", "ascii")) + f.write(b"property float x\n") + f.write(b"property float y\n") + f.write(b"property float z\n") + if rgb is not None: + f.write(b"property uchar red\n") + f.write(b"property uchar green\n") + f.write(b"property uchar blue\n") + if faces is not None: + f.write(bytes(f"element face {len(faces)}\n", "ascii")) + f.write(b"property list uchar int vertex_index\n") + f.write(b"end_header\n") + + if rgb is not None: + rgb = (rgb * 255.499).round().astype(int) + vertices = [ + (*coord, *rgb) + for coord, rgb in zip( + coords.tolist(), + rgb.tolist(), + ) + ] + format = struct.Struct("<3f3B") + for item in vertices: + f.write(format.pack(*item)) + else: + format = struct.Struct("<3f") + for vertex in coords.tolist(): + f.write(format.pack(*vertex)) + + if faces is not None: + format = struct.Struct(" str: + # TODO: Dhruv. Remove by Diffusers release 0.33.0 + # Added to prevent breaking existing code + if not is_imageio_available(): + logger.warning( + ( + "It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n" + "These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n" + "Support for the OpenCV backend will be deprecated in a future Diffusers version" + ) + ) + return _legacy_export_to_video(video_frames, output_video_path, fps) + + if is_imageio_available(): + import imageio + else: + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + ( + "Found an existing imageio backend in your environment. Attempting to export video with imageio. \n" + "Unable to find a compatible ffmpeg installation in your environment to use with imageio. Please install via `pip install imageio-ffmpeg" + ) + ) + + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + if isinstance(video_frames[0], np.ndarray): + video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + + elif isinstance(video_frames[0], PIL.Image.Image): + video_frames = [np.array(frame) for frame in video_frames] + + with imageio.get_writer(output_video_path, fps=fps) as writer: + for frame in video_frames: + writer.append_data(frame) + + return output_video_path diff --git a/diffusers/src/diffusers/utils/hub_utils.py b/diffusers/src/diffusers/utils/hub_utils.py new file mode 100755 index 0000000..1cdc02e --- /dev/null +++ b/diffusers/src/diffusers/utils/hub_utils.py @@ -0,0 +1,603 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import os +import re +import sys +import tempfile +import traceback +import warnings +from pathlib import Path +from typing import Dict, List, Optional, Union +from uuid import uuid4 + +from huggingface_hub import ( + ModelCard, + ModelCardData, + create_repo, + hf_hub_download, + model_info, + snapshot_download, + upload_folder, +) +from huggingface_hub.constants import HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE +from huggingface_hub.file_download import REGEX_COMMIT_HASH +from huggingface_hub.utils import ( + EntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, + is_jinja_available, + validate_hf_hub_args, +) +from packaging import version +from requests import HTTPError + +from .. import __version__ +from .constants import ( + DEPRECATED_REVISION_ARGS, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, +) +from .import_utils import ( + ENV_VARS_TRUE_VALUES, + _flax_version, + _jax_version, + _onnxruntime_version, + _torch_version, + is_flax_available, + is_onnx_available, + is_torch_available, +) +from .logging import get_logger + + +logger = get_logger(__name__) + +MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md" +SESSION_ID = uuid4().hex + + +def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: + """ + Formats a user-agent string with basic info about a request. + """ + ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}" + if HF_HUB_DISABLE_TELEMETRY or HF_HUB_OFFLINE: + return ua + "; telemetry/off" + if is_torch_available(): + ua += f"; torch/{_torch_version}" + if is_flax_available(): + ua += f"; jax/{_jax_version}" + ua += f"; flax/{_flax_version}" + if is_onnx_available(): + ua += f"; onnxruntime/{_onnxruntime_version}" + # CI will set this value to True + if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES: + ua += "; is_ci/true" + if isinstance(user_agent, dict): + ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += "; " + user_agent + return ua + + +def load_or_create_model_card( + repo_id_or_path: str = None, + token: Optional[str] = None, + is_pipeline: bool = False, + from_training: bool = False, + model_description: Optional[str] = None, + base_model: str = None, + prompt: Optional[str] = None, + license: Optional[str] = None, + widget: Optional[List[dict]] = None, + inference: Optional[bool] = None, +) -> ModelCard: + """ + Loads or creates a model card. + + Args: + repo_id_or_path (`str`): + The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card. + token (`str`, *optional*): + Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more + details. + is_pipeline (`bool`): + Boolean to indicate if we're adding tag to a [`DiffusionPipeline`]. + from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script. + model_description (`str`, *optional*): Model description to add to the model card. Helpful when using + `load_or_create_model_card` from a training script. + base_model (`str`): Base model identifier (e.g., "stabilityai/stable-diffusion-xl-base-1.0"). Useful + for DreamBooth-like training. + prompt (`str`, *optional*): Prompt used for training. Useful for DreamBooth-like training. + license: (`str`, *optional*): License of the output artifact. Helpful when using + `load_or_create_model_card` from a training script. + widget (`List[dict]`, *optional*): Widget to accompany a gallery template. + inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using + `load_or_create_model_card` from a training script. + """ + if not is_jinja_available(): + raise ValueError( + "Modelcard rendering is based on Jinja templates." + " Please make sure to have `jinja` installed before using `load_or_create_model_card`." + " To install it, please run `pip install Jinja2`." + ) + + try: + # Check if the model card is present on the remote repo + model_card = ModelCard.load(repo_id_or_path, token=token) + except (EntryNotFoundError, RepositoryNotFoundError): + # Otherwise create a model card from template + if from_training: + model_card = ModelCard.from_template( + card_data=ModelCardData( # Card metadata object that will be converted to YAML block + license=license, + library_name="diffusers", + inference=inference, + base_model=base_model, + instance_prompt=prompt, + widget=widget, + ), + template_path=MODEL_CARD_TEMPLATE_PATH, + model_description=model_description, + ) + else: + card_data = ModelCardData() + component = "pipeline" if is_pipeline else "model" + if model_description is None: + model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated." + model_card = ModelCard.from_template(card_data, model_description=model_description) + + return model_card + + +def populate_model_card(model_card: ModelCard, tags: Union[str, List[str]] = None) -> ModelCard: + """Populates the `model_card` with library name and optional tags.""" + if model_card.data.library_name is None: + model_card.data.library_name = "diffusers" + + if tags is not None: + if isinstance(tags, str): + tags = [tags] + if model_card.data.tags is None: + model_card.data.tags = [] + for tag in tags: + model_card.data.tags.append(tag) + + return model_card + + +def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None): + """ + Extracts the commit hash from a resolved filename toward a cache file. + """ + if resolved_file is None or commit_hash is not None: + return commit_hash + resolved_file = str(Path(resolved_file).as_posix()) + search = re.search(r"snapshots/([^/]+)/", resolved_file) + if search is None: + return None + commit_hash = search.groups()[0] + return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None + + +# Old default cache path, potentially to be migrated. +# This logic was more or less taken from `transformers`, with the following differences: +# - Diffusers doesn't use custom environment variables to specify the cache path. +# - There is no need to migrate the cache format, just move the files to the new location. +hf_cache_home = os.path.expanduser( + os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) +) +old_diffusers_cache = os.path.join(hf_cache_home, "diffusers") + + +def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None: + if new_cache_dir is None: + new_cache_dir = HF_HUB_CACHE + if old_cache_dir is None: + old_cache_dir = old_diffusers_cache + + old_cache_dir = Path(old_cache_dir).expanduser() + new_cache_dir = Path(new_cache_dir).expanduser() + for old_blob_path in old_cache_dir.glob("**/blobs/*"): + if old_blob_path.is_file() and not old_blob_path.is_symlink(): + new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir) + new_blob_path.parent.mkdir(parents=True, exist_ok=True) + os.replace(old_blob_path, new_blob_path) + try: + os.symlink(new_blob_path, old_blob_path) + except OSError: + logger.warning( + "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded." + ) + # At this point, old_cache_dir contains symlinks to the new cache (it can still be used). + + +cache_version_file = os.path.join(HF_HUB_CACHE, "version_diffusers_cache.txt") +if not os.path.isfile(cache_version_file): + cache_version = 0 +else: + with open(cache_version_file) as f: + try: + cache_version = int(f.read()) + except ValueError: + cache_version = 0 + +if cache_version < 1: + old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0 + if old_cache_is_not_empty: + logger.warning( + "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your " + "existing cached models. This is a one-time operation, you can interrupt it or run it " + "later by calling `diffusers.utils.hub_utils.move_cache()`." + ) + try: + move_cache() + except Exception as e: + trace = "\n".join(traceback.format_tb(e.__traceback__)) + logger.error( + f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease " + "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole " + "message and we will do our best to help." + ) + +if cache_version < 1: + try: + os.makedirs(HF_HUB_CACHE, exist_ok=True) + with open(cache_version_file, "w") as f: + f.write("1") + except Exception: + logger.warning( + f"There was a problem when trying to write in your cache folder ({HF_HUB_CACHE}). Please, ensure " + "the directory exists and can be written to." + ) + + +def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None: + splits = weights_name.split(".") + split_index = -2 if weights_name.endswith(".index.json") else -1 + splits = splits[:-split_index] + [variant] + splits[-split_index:] + weights_name = ".".join(splits) + + return weights_name + + +@validate_hf_hub_args +def _get_model_file( + pretrained_model_name_or_path: Union[str, Path], + *, + weights_name: str, + subfolder: Optional[str] = None, + cache_dir: Optional[str] = None, + force_download: bool = False, + proxies: Optional[Dict] = None, + local_files_only: bool = False, + token: Optional[str] = None, + user_agent: Optional[Union[Dict, str]] = None, + revision: Optional[str] = None, + commit_hash: Optional[str] = None, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + return pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + # Load from a PyTorch checkpoint + model_file = os.path.join(pretrained_model_name_or_path, weights_name) + return model_file + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + ): + model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + return model_file + else: + raise EnvironmentError( + f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." + ) + else: + # 1. First check if deprecated way of loading from branches is used + if ( + revision in DEPRECATED_REVISION_ARGS + and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME) + and version.parse(version.parse(__version__).base_version) >= version.parse("0.22.0") + ): + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=_add_variant(weights_name, revision), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision or commit_hash, + ) + warnings.warn( + f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", + FutureWarning, + ) + return model_file + except: # noqa: E722 + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.", + FutureWarning, + ) + try: + # 2. Load model file as usual + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=weights_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision or commit_hash, + ) + return model_file + + except RepositoryNotFoundError as e: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `token` or log in with `huggingface-cli " + "login`." + ) from e + except RevisionNotFoundError as e: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) from e + except EntryNotFoundError as e: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." + ) from e + except HTTPError as e: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{e}" + ) from e + except ValueError as e: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {weights_name} or" + " \nCheckout your internet connection or see how to run the library in" + " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) from e + except EnvironmentError as e: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {weights_name}" + ) from e + + +# Adapted from +# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976 +# Differences are in parallelization of shard downloads and checking if shards are present. + + +def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames): + shards_path = os.path.join(local_dir, subfolder) + shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames] + for shard_file in shard_filenames: + if not os.path.exists(shard_file): + raise ValueError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + + +def _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_filename, + cache_dir=None, + proxies=None, + local_files_only=False, + token=None, + user_agent=None, + revision=None, + subfolder="", +): + """ + For a given model: + + - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the + Hub + - returns the list of paths to all the shards, as well as some metadata. + + For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the + index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). + """ + if not os.path.isfile(index_filename): + raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + + original_shard_filenames = sorted(set(index["weight_map"].values())) + sharded_metadata = index["metadata"] + sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys()) + sharded_metadata["weight_map"] = index["weight_map"].copy() + shards_path = os.path.join(pretrained_model_name_or_path, subfolder) + + # First, let's deal with local folder. + if os.path.isdir(pretrained_model_name_or_path): + _check_if_shards_exist_locally( + pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames + ) + return shards_path, sharded_metadata + + # At this stage pretrained_model_name_or_path is a model identifier on the Hub + allow_patterns = original_shard_filenames + if subfolder is not None: + allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns] + + ignore_patterns = ["*.json", "*.md"] + if not local_files_only: + # `model_info` call must guarded with the above condition. + model_files_info = model_info(pretrained_model_name_or_path, revision=revision) + for shard_file in original_shard_filenames: + shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) + if not shard_file_present: + raise EnvironmentError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + + try: + # Load from URL + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + if subfolder is not None: + cached_folder = os.path.join(cached_folder, subfolder) + + # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so + # we don't have to catch them here. We have also dealt with EntryNotFoundError. + except HTTPError as e: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" + " again after checking your internet connection." + ) from e + + # If `local_files_only=True`, `cached_folder` may not contain all the shard files. + elif local_files_only: + _check_if_shards_exist_locally( + local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames + ) + if subfolder is not None: + cached_folder = os.path.join(cached_folder, subfolder) + + return cached_folder, sharded_metadata + + +class PushToHubMixin: + """ + A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub. + """ + + def _upload_folder( + self, + working_dir: Union[str, os.PathLike], + repo_id: str, + token: Optional[str] = None, + commit_message: Optional[str] = None, + create_pr: bool = False, + ): + """ + Uploads all files in `working_dir` to `repo_id`. + """ + if commit_message is None: + if "Model" in self.__class__.__name__: + commit_message = "Upload model" + elif "Scheduler" in self.__class__.__name__: + commit_message = "Upload scheduler" + else: + commit_message = f"Upload {self.__class__.__name__}" + + logger.info(f"Uploading the files of {working_dir} to {repo_id}.") + return upload_folder( + repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr + ) + + def push_to_hub( + self, + repo_id: str, + commit_message: Optional[str] = None, + private: Optional[bool] = None, + token: Optional[str] = None, + create_pr: bool = False, + safe_serialization: bool = True, + variant: Optional[str] = None, + ) -> str: + """ + Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub. + + Parameters: + repo_id (`str`): + The name of the repository you want to push your model, scheduler, or pipeline files to. It should + contain your organization name when pushing to an organization. `repo_id` can also be a path to a local + directory. + commit_message (`str`, *optional*): + Message to commit while pushing. Default to `"Upload {object}"`. + private (`bool`, *optional*): + Whether or not the repository created should be private. + token (`str`, *optional*): + The token to use as HTTP bearer authorization for remote files. The token generated when running + `huggingface-cli login` (stored in `~/.huggingface`). + create_pr (`bool`, *optional*, defaults to `False`): + Whether or not to create a PR with the uploaded files or directly commit. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether or not to convert the model weights to the `safetensors` format. + variant (`str`, *optional*): + If specified, weights are saved in the format `pytorch_model..bin`. + + Examples: + + ```python + from diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="unet") + + # Push the `unet` to your namespace with the name "my-finetuned-unet". + unet.push_to_hub("my-finetuned-unet") + + # Push the `unet` to an organization with the name "my-finetuned-unet". + unet.push_to_hub("your-org/my-finetuned-unet") + ``` + """ + repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id + + # Create a new empty model card and eventually tag it + model_card = load_or_create_model_card(repo_id, token=token) + model_card = populate_model_card(model_card) + + # Save all files. + save_kwargs = {"safe_serialization": safe_serialization} + if "Scheduler" not in self.__class__.__name__: + save_kwargs.update({"variant": variant}) + + with tempfile.TemporaryDirectory() as tmpdir: + self.save_pretrained(tmpdir, **save_kwargs) + + # Update model card if needed: + model_card.save(os.path.join(tmpdir, "README.md")) + + return self._upload_folder( + tmpdir, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) diff --git a/diffusers/src/diffusers/utils/import_utils.py b/diffusers/src/diffusers/utils/import_utils.py new file mode 100755 index 0000000..34cc5fc --- /dev/null +++ b/diffusers/src/diffusers/utils/import_utils.py @@ -0,0 +1,838 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" + +import importlib.util +import operator as op +import os +import sys +from collections import OrderedDict +from itertools import chain +from types import ModuleType +from typing import Any, Union + +from huggingface_hub.utils import is_jinja_available # noqa: F401 +from packaging import version +from packaging.version import Version, parse + +from . import logging + + +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() +USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() +USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper() +DIFFUSERS_SLOW_IMPORT = os.environ.get("DIFFUSERS_SLOW_IMPORT", "FALSE").upper() +DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES + +STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + +_torch_version = "N/A" +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available = importlib.util.find_spec("torch") is not None + if _torch_available: + try: + _torch_version = importlib_metadata.version("torch") + logger.info(f"PyTorch version {_torch_version} available.") + except importlib_metadata.PackageNotFoundError: + _torch_available = False +else: + logger.info("Disabling PyTorch because USE_TORCH is set") + _torch_available = False + +_torch_xla_available = importlib.util.find_spec("torch_xla") is not None +if _torch_xla_available: + try: + _torch_xla_version = importlib_metadata.version("torch_xla") + logger.info(f"PyTorch XLA version {_torch_xla_version} available.") + except ImportError: + _torch_xla_available = False + +# check whether torch_npu is available +_torch_npu_available = importlib.util.find_spec("torch_npu") is not None +if _torch_npu_available: + try: + _torch_npu_version = importlib_metadata.version("torch_npu") + logger.info(f"torch_npu version {_torch_npu_version} available.") + except ImportError: + _torch_npu_available = False + +_jax_version = "N/A" +_flax_version = "N/A" +if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None + if _flax_available: + try: + _jax_version = importlib_metadata.version("jax") + _flax_version = importlib_metadata.version("flax") + logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") + except importlib_metadata.PackageNotFoundError: + _flax_available = False +else: + _flax_available = False + +if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES: + _safetensors_available = importlib.util.find_spec("safetensors") is not None + if _safetensors_available: + try: + _safetensors_version = importlib_metadata.version("safetensors") + logger.info(f"Safetensors version {_safetensors_version} available.") + except importlib_metadata.PackageNotFoundError: + _safetensors_available = False +else: + logger.info("Disabling Safetensors because USE_TF is set") + _safetensors_available = False + +_transformers_available = importlib.util.find_spec("transformers") is not None +try: + _transformers_version = importlib_metadata.version("transformers") + logger.debug(f"Successfully imported transformers version {_transformers_version}") +except importlib_metadata.PackageNotFoundError: + _transformers_available = False + + +_inflect_available = importlib.util.find_spec("inflect") is not None +try: + _inflect_version = importlib_metadata.version("inflect") + logger.debug(f"Successfully imported inflect version {_inflect_version}") +except importlib_metadata.PackageNotFoundError: + _inflect_available = False + + +_unidecode_available = importlib.util.find_spec("unidecode") is not None +try: + _unidecode_version = importlib_metadata.version("unidecode") + logger.debug(f"Successfully imported unidecode version {_unidecode_version}") +except importlib_metadata.PackageNotFoundError: + _unidecode_available = False + +_onnxruntime_version = "N/A" +_onnx_available = importlib.util.find_spec("onnxruntime") is not None +if _onnx_available: + candidates = ( + "onnxruntime", + "onnxruntime-gpu", + "ort_nightly_gpu", + "onnxruntime-directml", + "onnxruntime-openvino", + "ort_nightly_directml", + "onnxruntime-rocm", + "onnxruntime-training", + ) + _onnxruntime_version = None + # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu + for pkg in candidates: + try: + _onnxruntime_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _onnx_available = _onnxruntime_version is not None + if _onnx_available: + logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") + +# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed. +# _opencv_available = importlib.util.find_spec("opencv-python") is not None +try: + candidates = ( + "opencv-python", + "opencv-contrib-python", + "opencv-python-headless", + "opencv-contrib-python-headless", + ) + _opencv_version = None + for pkg in candidates: + try: + _opencv_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _opencv_available = _opencv_version is not None + if _opencv_available: + logger.debug(f"Successfully imported cv2 version {_opencv_version}") +except importlib_metadata.PackageNotFoundError: + _opencv_available = False + +_scipy_available = importlib.util.find_spec("scipy") is not None +try: + _scipy_version = importlib_metadata.version("scipy") + logger.debug(f"Successfully imported scipy version {_scipy_version}") +except importlib_metadata.PackageNotFoundError: + _scipy_available = False + +_librosa_available = importlib.util.find_spec("librosa") is not None +try: + _librosa_version = importlib_metadata.version("librosa") + logger.debug(f"Successfully imported librosa version {_librosa_version}") +except importlib_metadata.PackageNotFoundError: + _librosa_available = False + +_accelerate_available = importlib.util.find_spec("accelerate") is not None +try: + _accelerate_version = importlib_metadata.version("accelerate") + logger.debug(f"Successfully imported accelerate version {_accelerate_version}") +except importlib_metadata.PackageNotFoundError: + _accelerate_available = False + +_xformers_available = importlib.util.find_spec("xformers") is not None +try: + _xformers_version = importlib_metadata.version("xformers") + if _torch_available: + _torch_version = importlib_metadata.version("torch") + if version.Version(_torch_version) < version.Version("1.12"): + raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12") + + logger.debug(f"Successfully imported xformers version {_xformers_version}") +except importlib_metadata.PackageNotFoundError: + _xformers_available = False + +_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None +try: + _k_diffusion_version = importlib_metadata.version("k_diffusion") + logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}") +except importlib_metadata.PackageNotFoundError: + _k_diffusion_available = False + +_note_seq_available = importlib.util.find_spec("note_seq") is not None +try: + _note_seq_version = importlib_metadata.version("note_seq") + logger.debug(f"Successfully imported note-seq version {_note_seq_version}") +except importlib_metadata.PackageNotFoundError: + _note_seq_available = False + +_wandb_available = importlib.util.find_spec("wandb") is not None +try: + _wandb_version = importlib_metadata.version("wandb") + logger.debug(f"Successfully imported wandb version {_wandb_version }") +except importlib_metadata.PackageNotFoundError: + _wandb_available = False + + +_tensorboard_available = importlib.util.find_spec("tensorboard") +try: + _tensorboard_version = importlib_metadata.version("tensorboard") + logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}") +except importlib_metadata.PackageNotFoundError: + _tensorboard_available = False + + +_compel_available = importlib.util.find_spec("compel") +try: + _compel_version = importlib_metadata.version("compel") + logger.debug(f"Successfully imported compel version {_compel_version}") +except importlib_metadata.PackageNotFoundError: + _compel_available = False + + +_ftfy_available = importlib.util.find_spec("ftfy") is not None +try: + _ftfy_version = importlib_metadata.version("ftfy") + logger.debug(f"Successfully imported ftfy version {_ftfy_version}") +except importlib_metadata.PackageNotFoundError: + _ftfy_available = False + + +_bs4_available = importlib.util.find_spec("bs4") is not None +try: + # importlib metadata under different name + _bs4_version = importlib_metadata.version("beautifulsoup4") + logger.debug(f"Successfully imported ftfy version {_bs4_version}") +except importlib_metadata.PackageNotFoundError: + _bs4_available = False + +_torchsde_available = importlib.util.find_spec("torchsde") is not None +try: + _torchsde_version = importlib_metadata.version("torchsde") + logger.debug(f"Successfully imported torchsde version {_torchsde_version}") +except importlib_metadata.PackageNotFoundError: + _torchsde_available = False + +_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None +try: + _invisible_watermark_version = importlib_metadata.version("invisible-watermark") + logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}") +except importlib_metadata.PackageNotFoundError: + _invisible_watermark_available = False + + +_peft_available = importlib.util.find_spec("peft") is not None +try: + _peft_version = importlib_metadata.version("peft") + logger.debug(f"Successfully imported peft version {_peft_version}") +except importlib_metadata.PackageNotFoundError: + _peft_available = False + +_torchvision_available = importlib.util.find_spec("torchvision") is not None +try: + _torchvision_version = importlib_metadata.version("torchvision") + logger.debug(f"Successfully imported torchvision version {_torchvision_version}") +except importlib_metadata.PackageNotFoundError: + _torchvision_available = False + +_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None +try: + _sentencepiece_version = importlib_metadata.version("sentencepiece") + logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}") +except importlib_metadata.PackageNotFoundError: + _sentencepiece_available = False + +_matplotlib_available = importlib.util.find_spec("matplotlib") is not None +try: + _matplotlib_version = importlib_metadata.version("matplotlib") + logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}") +except importlib_metadata.PackageNotFoundError: + _matplotlib_available = False + +_timm_available = importlib.util.find_spec("timm") is not None +if _timm_available: + try: + _timm_version = importlib_metadata.version("timm") + logger.info(f"Timm version {_timm_version} available.") + except importlib_metadata.PackageNotFoundError: + _timm_available = False + + +def is_timm_available(): + return _timm_available + + +_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None +try: + _bitsandbytes_version = importlib_metadata.version("bitsandbytes") + logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}") +except importlib_metadata.PackageNotFoundError: + _bitsandbytes_available = False + +_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ) + +_imageio_available = importlib.util.find_spec("imageio") is not None +if _imageio_available: + try: + _imageio_version = importlib_metadata.version("imageio") + logger.debug(f"Successfully imported imageio version {_imageio_version}") + + except importlib_metadata.PackageNotFoundError: + _imageio_available = False + + +def is_torch_available(): + return _torch_available + + +def is_torch_xla_available(): + return _torch_xla_available + + +def is_torch_npu_available(): + return _torch_npu_available + + +def is_flax_available(): + return _flax_available + + +def is_transformers_available(): + return _transformers_available + + +def is_inflect_available(): + return _inflect_available + + +def is_unidecode_available(): + return _unidecode_available + + +def is_onnx_available(): + return _onnx_available + + +def is_opencv_available(): + return _opencv_available + + +def is_scipy_available(): + return _scipy_available + + +def is_librosa_available(): + return _librosa_available + + +def is_xformers_available(): + return _xformers_available + + +def is_accelerate_available(): + return _accelerate_available + + +def is_k_diffusion_available(): + return _k_diffusion_available + + +def is_note_seq_available(): + return _note_seq_available + + +def is_wandb_available(): + return _wandb_available + + +def is_tensorboard_available(): + return _tensorboard_available + + +def is_compel_available(): + return _compel_available + + +def is_ftfy_available(): + return _ftfy_available + + +def is_bs4_available(): + return _bs4_available + + +def is_torchsde_available(): + return _torchsde_available + + +def is_invisible_watermark_available(): + return _invisible_watermark_available + + +def is_peft_available(): + return _peft_available + + +def is_torchvision_available(): + return _torchvision_available + + +def is_matplotlib_available(): + return _matplotlib_available + + +def is_safetensors_available(): + return _safetensors_available + + +def is_bitsandbytes_available(): + return _bitsandbytes_available + + +def is_google_colab(): + return _is_google_colab + + +def is_sentencepiece_available(): + return _sentencepiece_available + + +def is_imageio_available(): + return _imageio_available + + +# docstyle-ignore +FLAX_IMPORT_ERROR = """ +{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the +installation page: https://github.com/google/flax and follow the ones that match your environment. +""" + +# docstyle-ignore +INFLECT_IMPORT_ERROR = """ +{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install +inflect` +""" + +# docstyle-ignore +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +""" + +# docstyle-ignore +ONNX_IMPORT_ERROR = """ +{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip +install onnxruntime` +""" + +# docstyle-ignore +OPENCV_IMPORT_ERROR = """ +{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip +install opencv-python` +""" + +# docstyle-ignore +SCIPY_IMPORT_ERROR = """ +{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install +scipy` +""" + +# docstyle-ignore +LIBROSA_IMPORT_ERROR = """ +{0} requires the librosa library but it was not found in your environment. Checkout the instructions on the +installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment. +""" + +# docstyle-ignore +TRANSFORMERS_IMPORT_ERROR = """ +{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip +install transformers` +""" + +# docstyle-ignore +UNIDECODE_IMPORT_ERROR = """ +{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install +Unidecode` +""" + +# docstyle-ignore +K_DIFFUSION_IMPORT_ERROR = """ +{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip +install k-diffusion` +""" + +# docstyle-ignore +NOTE_SEQ_IMPORT_ERROR = """ +{0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip +install note-seq` +""" + +# docstyle-ignore +WANDB_IMPORT_ERROR = """ +{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip +install wandb` +""" + +# docstyle-ignore +TENSORBOARD_IMPORT_ERROR = """ +{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip +install tensorboard` +""" + + +# docstyle-ignore +COMPEL_IMPORT_ERROR = """ +{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel` +""" + +# docstyle-ignore +BS4_IMPORT_ERROR = """ +{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip: +`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +FTFY_IMPORT_ERROR = """ +{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the +installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones +that match your environment. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +TORCHSDE_IMPORT_ERROR = """ +{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde` +""" + +# docstyle-ignore +INVISIBLE_WATERMARK_IMPORT_ERROR = """ +{0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0` +""" + +# docstyle-ignore +PEFT_IMPORT_ERROR = """ +{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft` +""" + +# docstyle-ignore +SAFETENSORS_IMPORT_ERROR = """ +{0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors` +""" + +# docstyle-ignore +SENTENCEPIECE_IMPORT_ERROR = """ +{0} requires the sentencepiece library but it was not found in your environment. You can install it with pip: `pip install sentencepiece` +""" + + +# docstyle-ignore +BITSANDBYTES_IMPORT_ERROR = """ +{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes` +""" + +# docstyle-ignore +IMAGEIO_IMPORT_ERROR = """ +{0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg` +""" + +BACKENDS_MAPPING = OrderedDict( + [ + ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), + ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), + ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), + ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), + ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), + ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), + ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), + ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), + ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), + ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), + ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)), + ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), + ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), + ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), + ] +) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) + + if name in [ + "VersatileDiffusionTextToImagePipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionDualGuidedPipeline", + "StableDiffusionImageVariationPipeline", + "UnCLIPPipeline", + ] and is_transformers_version("<", "4.25.0"): + raise ImportError( + f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install" + " --upgrade transformers \n```" + ) + + if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version( + "<", "4.26.0" + ): + raise ImportError( + f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install" + " --upgrade transformers \n```" + ) + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattr__(cls, key): + if key.startswith("_") and key not in ["_load_connected_pipes", "_is_onnx"]: + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 +def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): + """ + Args: + Compares a library version to some requirement using a given operation. + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. + requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 +def is_torch_version(operation: str, version: str): + """ + Args: + Compares the current PyTorch version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(parse(_torch_version), operation, version) + + +def is_transformers_version(operation: str, version: str): + """ + Args: + Compares the current Transformers version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _transformers_available: + return False + return compare_versions(parse(_transformers_version), operation, version) + + +def is_accelerate_version(operation: str, version: str): + """ + Args: + Compares the current Accelerate version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _accelerate_available: + return False + return compare_versions(parse(_accelerate_version), operation, version) + + +def is_peft_version(operation: str, version: str): + """ + Args: + Compares the current PEFT version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _peft_version: + return False + return compare_versions(parse(_peft_version), operation, version) + + +def is_k_diffusion_version(operation: str, version: str): + """ + Args: + Compares the current k-diffusion version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _k_diffusion_available: + return False + return compare_versions(parse(_k_diffusion_version), operation, version) + + +def get_objects_from_module(module): + """ + Args: + Returns a dict of object names and values in a module, while skipping private/internal objects + module (ModuleType): + Module to extract the objects from. + + Returns: + dict: Dictionary of object names and corresponding values + """ + + objects = {} + for name in dir(module): + if name.startswith("_"): + continue + objects[name] = getattr(module, name) + + return objects + + +class OptionalDependencyNotAvailable(BaseException): + """An error indicating that an optional dependency of Diffusers was not found in the environment.""" + + +class _LazyModule(ModuleType): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + # Very heavily inspired by optuna.integration._IntegrationModule + # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py + def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None): + super().__init__(name) + self._modules = set(import_structure.keys()) + self._class_to_module = {} + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = import_structure + + # Needed for autocompletion in an IDE + def __dir__(self): + result = super().__dir__() + # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether + # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. + for attr in self.__all__: + if attr not in result: + result.append(attr) + return result + + def __getattr__(self, name: str) -> Any: + if name in self._objects: + return self._objects[name] + if name in self._modules: + value = self._get_module(name) + elif name in self._class_to_module.keys(): + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + else: + raise AttributeError(f"module {self.__name__} has no attribute {name}") + + setattr(self, name, value) + return value + + def _get_module(self, module_name: str): + try: + return importlib.import_module("." + module_name, self.__name__) + except Exception as e: + raise RuntimeError( + f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its" + f" traceback):\n{e}" + ) from e + + def __reduce__(self): + return (self.__class__, (self._name, self.__file__, self._import_structure)) diff --git a/diffusers/src/diffusers/utils/loading_utils.py b/diffusers/src/diffusers/utils/loading_utils.py new file mode 100755 index 0000000..b36664c --- /dev/null +++ b/diffusers/src/diffusers/utils/loading_utils.py @@ -0,0 +1,137 @@ +import os +import tempfile +from typing import Callable, List, Optional, Union +from urllib.parse import unquote, urlparse + +import PIL.Image +import PIL.ImageOps +import requests + +from .import_utils import BACKENDS_MAPPING, is_imageio_available + + +def load_image( + image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None +) -> PIL.Image.Image: + """ + Loads `image` to a PIL Image. + + Args: + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*): + A conversion method to apply to the image after loading it. When set to `None` the image will be converted + "RGB". + + Returns: + `PIL.Image.Image`: + A PIL Image. + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + image = PIL.Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path." + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise ValueError( + "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image." + ) + + image = PIL.ImageOps.exif_transpose(image) + + if convert_method is not None: + image = convert_method(image) + else: + image = image.convert("RGB") + + return image + + +def load_video( + video: str, + convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, +) -> List[PIL.Image.Image]: + """ + Loads `video` to a list of PIL Image. + + Args: + video (`str`): + A URL or Path to a video to convert to a list of PIL Image format. + convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): + A conversion method to apply to the video after loading it. When set to `None` the images will be converted + to "RGB". + + Returns: + `List[PIL.Image.Image]`: + The video as a list of PIL images. + """ + is_url = video.startswith("http://") or video.startswith("https://") + is_file = os.path.isfile(video) + was_tempfile_created = False + + if not (is_url or is_file): + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." + ) + + if is_url: + response = requests.get(video, stream=True) + if response.status_code != 200: + raise ValueError(f"Failed to download video. Status code: {response.status_code}") + + parsed_url = urlparse(video) + file_name = os.path.basename(unquote(parsed_url.path)) + + suffix = os.path.splitext(file_name)[1] or ".mp4" + video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name + + was_tempfile_created = True + + video_data = response.iter_content(chunk_size=8192) + with open(video_path, "wb") as f: + for chunk in video_data: + f.write(chunk) + + video = video_path + + pil_images = [] + if video.endswith(".gif"): + gif = PIL.Image.open(video) + try: + while True: + pil_images.append(gif.copy()) + gif.seek(gif.tell() + 1) + except EOFError: + pass + + else: + if is_imageio_available(): + import imageio + else: + raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video")) + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg" + ) + + with imageio.get_reader(video) as reader: + # Read all frames + for frame in reader: + pil_images.append(PIL.Image.fromarray(frame)) + + if was_tempfile_created: + os.remove(video_path) + + if convert_method is not None: + pil_images = convert_method(pil_images) + + return pil_images diff --git a/diffusers/src/diffusers/utils/logging.py b/diffusers/src/diffusers/utils/logging.py new file mode 100755 index 0000000..6f93450 --- /dev/null +++ b/diffusers/src/diffusers/utils/logging.py @@ -0,0 +1,341 @@ +# coding=utf-8 +# Copyright 2024 Optuna, Hugging Face +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Logging utilities.""" + +import logging +import os +import sys +import threading +from logging import ( + CRITICAL, # NOQA + DEBUG, # NOQA + ERROR, # NOQA + FATAL, # NOQA + INFO, # NOQA + NOTSET, # NOQA + WARN, # NOQA + WARNING, # NOQA +) +from typing import Dict, Optional + +from tqdm import auto as tqdm_lib + + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.WARNING + +_tqdm_active = True + + +def _get_default_logging_level() -> int: + """ + If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " + f"has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + + if sys.stderr: # only if sys.stderr exists, e.g. when not using pythonw in windows + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict() -> Dict[str, int]: + return log_levels + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Return a logger with the specified name. + + This function is not supposed to be directly accessed unless you are writing a custom diffusers module. + """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_verbosity() -> int: + """ + Return the current level for the 🤗 Diffusers' root logger as an `int`. + + Returns: + `int`: + Logging level integers which can be one of: + + - `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - `40`: `diffusers.logging.ERROR` + - `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - `20`: `diffusers.logging.INFO` + - `10`: `diffusers.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for the 🤗 Diffusers' root logger. + + Args: + verbosity (`int`): + Logging level which can be one of: + + - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - `diffusers.logging.ERROR` + - `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - `diffusers.logging.INFO` + - `diffusers.logging.DEBUG` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def set_verbosity_info() -> None: + """Set the verbosity to the `INFO` level.""" + return set_verbosity(INFO) + + +def set_verbosity_warning() -> None: + """Set the verbosity to the `WARNING` level.""" + return set_verbosity(WARNING) + + +def set_verbosity_debug() -> None: + """Set the verbosity to the `DEBUG` level.""" + return set_verbosity(DEBUG) + + +def set_verbosity_error() -> None: + """Set the verbosity to the `ERROR` level.""" + return set_verbosity(ERROR) + + +def disable_default_handler() -> None: + """Disable the default handler of the 🤗 Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of the 🤗 Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def add_handler(handler: logging.Handler) -> None: + """adds a handler to the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + """removes given handler from the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler in _get_library_root_logger().handlers + _get_library_root_logger().removeHandler(handler) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is disabled by default. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent + double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows: + ``` + [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for 🤗 Diffusers' loggers. + + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs) -> None: + """ + This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar() -> None: + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar() -> None: + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/diffusers/src/diffusers/utils/model_card_template.md b/diffusers/src/diffusers/utils/model_card_template.md new file mode 100755 index 0000000..f41b71e --- /dev/null +++ b/diffusers/src/diffusers/utils/model_card_template.md @@ -0,0 +1,24 @@ +--- +{{ card_data }} +--- + + + +{{ model_description }} + +## Intended uses & limitations + +#### How to use + +```python +# TODO: add an example code snippet for running this diffusion pipeline +``` + +#### Limitations and bias + +[TODO: provide examples of latent issues and potential remediations] + +## Training details + +[TODO: describe the data used to train the model] diff --git a/diffusers/src/diffusers/utils/outputs.py b/diffusers/src/diffusers/utils/outputs.py new file mode 100644 index 0000000..6080a86 --- /dev/null +++ b/diffusers/src/diffusers/utils/outputs.py @@ -0,0 +1,137 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generic utilities +""" + +from collections import OrderedDict +from dataclasses import fields, is_dataclass +from typing import Any, Tuple + +import numpy as np + +from .import_utils import is_torch_available, is_torch_version + + +def is_tensor(x) -> bool: + """ + Tests if `x` is a `torch.Tensor` or `np.ndarray`. + """ + if is_torch_available(): + import torch + + if isinstance(x, torch.Tensor): + return True + + return isinstance(x, np.ndarray) + + +class BaseOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + Python dictionary. + + + + You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple + first. + + + """ + + def __init_subclass__(cls) -> None: + """Register subclasses as pytree nodes. + + This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with + `static_graph=True` with modules that output `ModelOutput` subclasses. + """ + if is_torch_available(): + import torch.utils._pytree + + if is_torch_version("<", "2.2"): + torch.utils._pytree._register_pytree_node( + cls, + torch.utils._pytree._dict_flatten, + lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), + ) + else: + torch.utils._pytree.register_pytree_node( + cls, + torch.utils._pytree._dict_flatten, + lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), + ) + + def __post_init__(self) -> None: + class_fields = fields(self) + + # Safety and consistency checks + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and isinstance(first_field, dict): + for key, value in first_field.items(): + self[key] = value + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __getitem__(self, k: Any) -> Any: + if isinstance(k, str): + inner_dict = dict(self.items()) + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name: Any, value: Any) -> None: + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def __reduce__(self): + if not is_dataclass(self): + return super().__reduce__() + callable, _args, *remaining = super().__reduce__() + args = tuple(getattr(self, field.name) for field in fields(self)) + return callable, args, *remaining + + def to_tuple(self) -> Tuple[Any, ...]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. + """ + return tuple(self[k] for k in self.keys()) diff --git a/diffusers/src/diffusers/utils/peft_utils.py b/diffusers/src/diffusers/utils/peft_utils.py new file mode 100755 index 0000000..ca55192 --- /dev/null +++ b/diffusers/src/diffusers/utils/peft_utils.py @@ -0,0 +1,295 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PEFT utilities: Utilities related to peft library +""" + +import collections +import importlib +from typing import Optional + +from packaging import version + +from .import_utils import is_peft_available, is_torch_available + + +if is_torch_available(): + import torch + + +def recurse_remove_peft_layers(model): + r""" + Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`. + """ + from peft.tuners.tuners_utils import BaseTunerLayer + + has_base_layer_pattern = False + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + has_base_layer_pattern = hasattr(module, "base_layer") + break + + if has_base_layer_pattern: + from peft.utils import _get_submodules + + key_list = [key for key, _ in model.named_modules() if "lora" not in key] + for key in key_list: + try: + parent, target, target_name = _get_submodules(model, key) + except AttributeError: + continue + if hasattr(target, "base_layer"): + setattr(parent, target_name, target.get_base_layer()) + else: + # This is for backwards compatibility with PEFT <= 0.6.2. + # TODO can be removed once that PEFT version is no longer supported. + from peft.tuners.lora import LoraLayer + + for name, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + recurse_remove_peft_layers(module) + + module_replaced = False + + if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear): + new_module = torch.nn.Linear( + module.in_features, + module.out_features, + bias=module.bias is not None, + ).to(module.weight.device) + new_module.weight = module.weight + if module.bias is not None: + new_module.bias = module.bias + + module_replaced = True + elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d): + new_module = torch.nn.Conv2d( + module.in_channels, + module.out_channels, + module.kernel_size, + module.stride, + module.padding, + module.dilation, + module.groups, + ).to(module.weight.device) + + new_module.weight = module.weight + if module.bias is not None: + new_module.bias = module.bias + + module_replaced = True + + if module_replaced: + setattr(model, name, new_module) + del module + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return model + + +def scale_lora_layers(model, weight): + """ + Adjust the weightage given to the LoRA layers of the model. + + Args: + model (`torch.nn.Module`): + The model to scale. + weight (`float`): + The weight to be given to the LoRA layers. + """ + from peft.tuners.tuners_utils import BaseTunerLayer + + if weight == 1.0: + return + + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + module.scale_layer(weight) + + +def unscale_lora_layers(model, weight: Optional[float] = None): + """ + Removes the previously passed weight given to the LoRA layers of the model. + + Args: + model (`torch.nn.Module`): + The model to scale. + weight (`float`, *optional*): + The weight to be given to the LoRA layers. If no scale is passed the scale of the lora layer will be + re-initialized to the correct value. If 0.0 is passed, we will re-initialize the scale with the correct + value. + """ + from peft.tuners.tuners_utils import BaseTunerLayer + + if weight == 1.0: + return + + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + if weight is not None and weight != 0: + module.unscale_layer(weight) + elif weight is not None and weight == 0: + for adapter_name in module.active_adapters: + # if weight == 0 unscale should re-set the scale to the original value. + module.set_scale(adapter_name, 1.0) + + +def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): + rank_pattern = {} + alpha_pattern = {} + r = lora_alpha = list(rank_dict.values())[0] + + if len(set(rank_dict.values())) > 1: + # get the rank occuring the most number of times + r = collections.Counter(rank_dict.values()).most_common()[0][0] + + # for modules with rank different from the most occuring rank, add it to the `rank_pattern` + rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) + rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} + + if network_alpha_dict is not None and len(network_alpha_dict) > 0: + if len(set(network_alpha_dict.values())) > 1: + # get the alpha occuring the most number of times + lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] + + # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern` + alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) + if is_unet: + alpha_pattern = { + ".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v + for k, v in alpha_pattern.items() + } + else: + alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} + else: + lora_alpha = set(network_alpha_dict.values()).pop() + + # layer names without the Diffusers specific + target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) + use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict) + + lora_config_kwargs = { + "r": r, + "lora_alpha": lora_alpha, + "rank_pattern": rank_pattern, + "alpha_pattern": alpha_pattern, + "target_modules": target_modules, + "use_dora": use_dora, + } + return lora_config_kwargs + + +def get_adapter_name(model): + from peft.tuners.tuners_utils import BaseTunerLayer + + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + return f"default_{len(module.r)}" + return "default_0" + + +def set_adapter_layers(model, enabled=True): + from peft.tuners.tuners_utils import BaseTunerLayer + + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + # The recent version of PEFT needs to call `enable_adapters` instead + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=enabled) + else: + module.disable_adapters = not enabled + + +def delete_adapter_layers(model, adapter_name): + from peft.tuners.tuners_utils import BaseTunerLayer + + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "delete_adapter"): + module.delete_adapter(adapter_name) + else: + raise ValueError( + "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1" + ) + + # For transformers integration - we need to pop the adapter from the config + if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"): + model.peft_config.pop(adapter_name, None) + # In case all adapters are deleted, we need to delete the config + # and make sure to set the flag to False + if len(model.peft_config) == 0: + del model.peft_config + model._hf_peft_config_loaded = None + + +def set_weights_and_activate_adapters(model, adapter_names, weights): + from peft.tuners.tuners_utils import BaseTunerLayer + + def get_module_weight(weight_for_adapter, module_name): + if not isinstance(weight_for_adapter, dict): + # If weight_for_adapter is a single number, always return it. + return weight_for_adapter + + for layer_name, weight_ in weight_for_adapter.items(): + if layer_name in module_name: + return weight_ + + parts = module_name.split(".") + # e.g. key = "down_blocks.1.attentions.0" + key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}" + block_weight = weight_for_adapter.get(key, 1.0) + + return block_weight + + # iterate over each adapter, make it active and set the corresponding scaling weight + for adapter_name, weight in zip(adapter_names, weights): + for module_name, module in model.named_modules(): + if isinstance(module, BaseTunerLayer): + # For backward compatbility with previous PEFT versions + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_name) + else: + module.active_adapter = adapter_name + module.set_scale(adapter_name, get_module_weight(weight, module_name)) + + # set multiple active adapters + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + # For backward compatbility with previous PEFT versions + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_names) + else: + module.active_adapter = adapter_names + + +def check_peft_version(min_version: str) -> None: + r""" + Checks if the version of PEFT is compatible. + + Args: + version (`str`): + The version of PEFT to check against. + """ + if not is_peft_available(): + raise ValueError("PEFT is not installed. Please install it with `pip install peft`") + + is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) > version.parse(min_version) + + if not is_peft_version_compatible: + raise ValueError( + f"The version of PEFT you are using is not compatible, please use a version that is greater" + f" than {min_version}" + ) diff --git a/diffusers/src/diffusers/utils/pil_utils.py b/diffusers/src/diffusers/utils/pil_utils.py new file mode 100755 index 0000000..7667807 --- /dev/null +++ b/diffusers/src/diffusers/utils/pil_utils.py @@ -0,0 +1,67 @@ +from typing import List + +import PIL.Image +import PIL.ImageOps +from packaging import version +from PIL import Image + + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } + + +def pt_to_pil(images): + """ + Convert a torch image to a PIL image. + """ + images = (images / 2 + 0.5).clamp(0, 1) + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + images = numpy_to_pil(images) + return images + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image: + """ + Prepares a single grid of images. Useful for visualization purposes. + """ + assert len(images) == rows * cols + + if resize is not None: + images = [img.resize((resize, resize)) for img in images] + + w, h = images[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(images): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid diff --git a/diffusers/src/diffusers/utils/state_dict_utils.py b/diffusers/src/diffusers/utils/state_dict_utils.py new file mode 100755 index 0000000..62b114b --- /dev/null +++ b/diffusers/src/diffusers/utils/state_dict_utils.py @@ -0,0 +1,335 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +State dict utilities: utility methods for converting state dicts easily +""" + +import enum + +from .logging import get_logger + + +logger = get_logger(__name__) + + +class StateDictType(enum.Enum): + """ + The mode to use when converting state dicts. + """ + + DIFFUSERS_OLD = "diffusers_old" + KOHYA_SS = "kohya_ss" + PEFT = "peft" + DIFFUSERS = "diffusers" + + +# We need to define a proper mapping for Unet since it uses different output keys than text encoder +# e.g. to_q_lora -> q_proj / to_q +UNET_TO_DIFFUSERS = { + ".to_out_lora.up": ".to_out.0.lora_B", + ".to_out_lora.down": ".to_out.0.lora_A", + ".to_q_lora.down": ".to_q.lora_A", + ".to_q_lora.up": ".to_q.lora_B", + ".to_k_lora.down": ".to_k.lora_A", + ".to_k_lora.up": ".to_k.lora_B", + ".to_v_lora.down": ".to_v.lora_A", + ".to_v_lora.up": ".to_v.lora_B", + ".lora.up": ".lora_B", + ".lora.down": ".lora_A", + ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector", +} + + +DIFFUSERS_TO_PEFT = { + ".q_proj.lora_linear_layer.up": ".q_proj.lora_B", + ".q_proj.lora_linear_layer.down": ".q_proj.lora_A", + ".k_proj.lora_linear_layer.up": ".k_proj.lora_B", + ".k_proj.lora_linear_layer.down": ".k_proj.lora_A", + ".v_proj.lora_linear_layer.up": ".v_proj.lora_B", + ".v_proj.lora_linear_layer.down": ".v_proj.lora_A", + ".out_proj.lora_linear_layer.up": ".out_proj.lora_B", + ".out_proj.lora_linear_layer.down": ".out_proj.lora_A", + ".lora_linear_layer.up": ".lora_B", + ".lora_linear_layer.down": ".lora_A", + "text_projection.lora.down.weight": "text_projection.lora_A.weight", + "text_projection.lora.up.weight": "text_projection.lora_B.weight", +} + +DIFFUSERS_OLD_TO_PEFT = { + ".to_q_lora.up": ".q_proj.lora_B", + ".to_q_lora.down": ".q_proj.lora_A", + ".to_k_lora.up": ".k_proj.lora_B", + ".to_k_lora.down": ".k_proj.lora_A", + ".to_v_lora.up": ".v_proj.lora_B", + ".to_v_lora.down": ".v_proj.lora_A", + ".to_out_lora.up": ".out_proj.lora_B", + ".to_out_lora.down": ".out_proj.lora_A", + ".lora_linear_layer.up": ".lora_B", + ".lora_linear_layer.down": ".lora_A", +} + +PEFT_TO_DIFFUSERS = { + ".q_proj.lora_B": ".q_proj.lora_linear_layer.up", + ".q_proj.lora_A": ".q_proj.lora_linear_layer.down", + ".k_proj.lora_B": ".k_proj.lora_linear_layer.up", + ".k_proj.lora_A": ".k_proj.lora_linear_layer.down", + ".v_proj.lora_B": ".v_proj.lora_linear_layer.up", + ".v_proj.lora_A": ".v_proj.lora_linear_layer.down", + ".out_proj.lora_B": ".out_proj.lora_linear_layer.up", + ".out_proj.lora_A": ".out_proj.lora_linear_layer.down", + "to_k.lora_A": "to_k.lora.down", + "to_k.lora_B": "to_k.lora.up", + "to_q.lora_A": "to_q.lora.down", + "to_q.lora_B": "to_q.lora.up", + "to_v.lora_A": "to_v.lora.down", + "to_v.lora_B": "to_v.lora.up", + "to_out.0.lora_A": "to_out.0.lora.down", + "to_out.0.lora_B": "to_out.0.lora.up", +} + +DIFFUSERS_OLD_TO_DIFFUSERS = { + ".to_q_lora.up": ".q_proj.lora_linear_layer.up", + ".to_q_lora.down": ".q_proj.lora_linear_layer.down", + ".to_k_lora.up": ".k_proj.lora_linear_layer.up", + ".to_k_lora.down": ".k_proj.lora_linear_layer.down", + ".to_v_lora.up": ".v_proj.lora_linear_layer.up", + ".to_v_lora.down": ".v_proj.lora_linear_layer.down", + ".to_out_lora.up": ".out_proj.lora_linear_layer.up", + ".to_out_lora.down": ".out_proj.lora_linear_layer.down", + ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector", + ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector", + ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector", + ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector", +} + +PEFT_TO_KOHYA_SS = { + "lora_A": "lora_down", + "lora_B": "lora_up", + # This is not a comprehensive dict as kohya format requires replacing `.` with `_` in keys, + # adding prefixes and adding alpha values + # Check `convert_state_dict_to_kohya` for more +} + +PEFT_STATE_DICT_MAPPINGS = { + StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT, + StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT, +} + +DIFFUSERS_STATE_DICT_MAPPINGS = { + StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS, + StateDictType.PEFT: PEFT_TO_DIFFUSERS, +} + +KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS} + +KEYS_TO_ALWAYS_REPLACE = { + ".processor.": ".", +} + + +def convert_state_dict(state_dict, mapping): + r""" + Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values. + + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + mapping (`dict[str, str]`): + The mapping to use for conversion, the mapping should be a dictionary with the following structure: + - key: the pattern to replace + - value: the pattern to replace with + + Returns: + converted_state_dict (`dict`) + The converted state dict. + """ + converted_state_dict = {} + for k, v in state_dict.items(): + # First, filter out the keys that we always want to replace + for pattern in KEYS_TO_ALWAYS_REPLACE.keys(): + if pattern in k: + new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern] + k = k.replace(pattern, new_pattern) + + for pattern in mapping.keys(): + if pattern in k: + new_pattern = mapping[pattern] + k = k.replace(pattern, new_pattern) + break + converted_state_dict[k] = v + return converted_state_dict + + +def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs): + r""" + Converts a state dict to the PEFT format The state dict can be from previous diffusers format (`OLD_DIFFUSERS`), or + new diffusers format (`DIFFUSERS`). The method only supports the conversion from diffusers old/new to PEFT for now. + + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + original_type (`StateDictType`, *optional*): + The original type of the state dict, if not provided, the method will try to infer it automatically. + """ + if original_type is None: + # Old diffusers to PEFT + if any("to_out_lora" in k for k in state_dict.keys()): + original_type = StateDictType.DIFFUSERS_OLD + elif any("lora_linear_layer" in k for k in state_dict.keys()): + original_type = StateDictType.DIFFUSERS + else: + raise ValueError("Could not automatically infer state dict type") + + if original_type not in PEFT_STATE_DICT_MAPPINGS.keys(): + raise ValueError(f"Original type {original_type} is not supported") + + mapping = PEFT_STATE_DICT_MAPPINGS[original_type] + return convert_state_dict(state_dict, mapping) + + +def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): + r""" + Converts a state dict to new diffusers format. The state dict can be from previous diffusers format + (`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will + return the state dict as is. + + The method only supports the conversion from diffusers old, PEFT to diffusers new for now. + + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + original_type (`StateDictType`, *optional*): + The original type of the state dict, if not provided, the method will try to infer it automatically. + kwargs (`dict`, *args*): + Additional arguments to pass to the method. + + - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended + with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in + `get_peft_model_state_dict` method: + https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 + but we add it here in case we don't want to rely on that method. + """ + peft_adapter_name = kwargs.pop("adapter_name", None) + if peft_adapter_name is not None: + peft_adapter_name = "." + peft_adapter_name + else: + peft_adapter_name = "" + + if original_type is None: + # Old diffusers to PEFT + if any("to_out_lora" in k for k in state_dict.keys()): + original_type = StateDictType.DIFFUSERS_OLD + elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): + original_type = StateDictType.PEFT + elif any("lora_linear_layer" in k for k in state_dict.keys()): + # nothing to do + return state_dict + else: + raise ValueError("Could not automatically infer state dict type") + + if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys(): + raise ValueError(f"Original type {original_type} is not supported") + + mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type] + return convert_state_dict(state_dict, mapping) + + +def convert_unet_state_dict_to_peft(state_dict): + r""" + Converts a state dict from UNet format to diffusers format - i.e. by removing some keys + """ + mapping = UNET_TO_DIFFUSERS + return convert_state_dict(state_dict, mapping) + + +def convert_all_state_dict_to_peft(state_dict): + r""" + Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid + `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft` + """ + try: + peft_dict = convert_state_dict_to_peft(state_dict) + except Exception as e: + if str(e) == "Could not automatically infer state dict type": + peft_dict = convert_unet_state_dict_to_peft(state_dict) + else: + raise + + if not any("lora_A" in key or "lora_B" in key for key in peft_dict.keys()): + raise ValueError("Your LoRA was not converted to PEFT") + + return peft_dict + + +def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs): + r""" + Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc. + The method only supports the conversion from PEFT to Kohya for now. + + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + original_type (`StateDictType`, *optional*): + The original type of the state dict, if not provided, the method will try to infer it automatically. + kwargs (`dict`, *args*): + Additional arguments to pass to the method. + + - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended + with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in + `get_peft_model_state_dict` method: + https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 + but we add it here in case we don't want to rely on that method. + """ + try: + import torch + except ImportError: + logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.") + raise + + peft_adapter_name = kwargs.pop("adapter_name", None) + if peft_adapter_name is not None: + peft_adapter_name = "." + peft_adapter_name + else: + peft_adapter_name = "" + + if original_type is None: + if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): + original_type = StateDictType.PEFT + + if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys(): + raise ValueError(f"Original type {original_type} is not supported") + + # Use the convert_state_dict function with the appropriate mapping + kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT]) + kohya_ss_state_dict = {} + + # Additional logic for replacing header, alpha parameters `.` with `_` in all keys + for kohya_key, weight in kohya_ss_partial_state_dict.items(): + if "text_encoder_2." in kohya_key: + kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.") + elif "text_encoder." in kohya_key: + kohya_key = kohya_key.replace("text_encoder.", "lora_te1.") + elif "unet" in kohya_key: + kohya_key = kohya_key.replace("unet", "lora_unet") + elif "lora_magnitude_vector" in kohya_key: + kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale") + + kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) + kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names + kohya_ss_state_dict[kohya_key] = weight + if "lora_down" in kohya_key: + alpha_key = f'{kohya_key.split(".")[0]}.alpha' + kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight)) + + return kohya_ss_state_dict diff --git a/diffusers/src/diffusers/utils/testing_utils.py b/diffusers/src/diffusers/utils/testing_utils.py new file mode 100644 index 0000000..be3e998 --- /dev/null +++ b/diffusers/src/diffusers/utils/testing_utils.py @@ -0,0 +1,1035 @@ +import functools +import importlib +import inspect +import io +import logging +import multiprocessing +import os +import random +import re +import struct +import sys +import tempfile +import time +import unittest +import urllib.parse +from contextlib import contextmanager +from io import BytesIO, StringIO +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import PIL.ImageOps +import requests +from numpy.linalg import norm +from packaging import version + +from .import_utils import ( + BACKENDS_MAPPING, + is_compel_available, + is_flax_available, + is_note_seq_available, + is_onnx_available, + is_opencv_available, + is_peft_available, + is_timm_available, + is_torch_available, + is_torch_version, + is_torchsde_available, + is_transformers_available, +) +from .logging import get_logger + + +global_rng = random.Random() + +logger = get_logger(__name__) + +_required_peft_version = is_peft_available() and version.parse( + version.parse(importlib.metadata.version("peft")).base_version +) > version.parse("0.5") +_required_transformers_version = is_transformers_available() and version.parse( + version.parse(importlib.metadata.version("transformers")).base_version +) > version.parse("4.33") + +USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version + +if is_torch_available(): + import torch + + # Set a backend environment variable for any extra module import required for a custom accelerator + if "DIFFUSERS_TEST_BACKEND" in os.environ: + backend = os.environ["DIFFUSERS_TEST_BACKEND"] + try: + _ = importlib.import_module(backend) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module \ + to enable a specified backend.):\n{e}" + ) from e + + if "DIFFUSERS_TEST_DEVICE" in os.environ: + torch_device = os.environ["DIFFUSERS_TEST_DEVICE"] + try: + # try creating device to see if provided device is valid + _ = torch.device(torch_device) + except RuntimeError as e: + raise RuntimeError( + f"Unknown testing device specified by environment variable `DIFFUSERS_TEST_DEVICE`: {torch_device}" + ) from e + logger.info(f"torch_device overrode to {torch_device}") + else: + torch_device = "cuda" if torch.cuda.is_available() else "cpu" + is_torch_higher_equal_than_1_12 = version.parse( + version.parse(torch.__version__).base_version + ) >= version.parse("1.12") + + if is_torch_higher_equal_than_1_12: + # Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details + mps_backend_registered = hasattr(torch.backends, "mps") + torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device + + +def torch_all_close(a, b, *args, **kwargs): + if not is_torch_available(): + raise ValueError("PyTorch needs to be installed to use this function.") + if not torch.allclose(a, b, *args, **kwargs): + assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}." + return True + + +def numpy_cosine_similarity_distance(a, b): + similarity = np.dot(a, b) / (norm(a) * norm(b)) + distance = 1.0 - similarity.mean() + + return distance + + +def print_tensor_test( + tensor, + limit_to_slices=None, + max_torch_print=None, + filename="test_corrections.txt", + expected_tensor_name="expected_slice", +): + if max_torch_print: + torch.set_printoptions(threshold=10_000) + + test_name = os.environ.get("PYTEST_CURRENT_TEST") + if not torch.is_tensor(tensor): + tensor = torch.from_numpy(tensor) + if limit_to_slices: + tensor = tensor[0, -3:, -3:, -1] + + tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "") + # format is usually: + # expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161]) + output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array") + test_file, test_class, test_fn = test_name.split("::") + test_fn = test_fn.split()[0] + with open(filename, "a") as f: + print("::".join([test_file, test_class, test_fn, output_str]), file=f) + + +def get_tests_dir(append_path=None): + """ + Args: + append_path: optional path to append to the tests dir path + Return: + The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is + joined after the `tests` dir the former is provided. + """ + # this function caller's __file__ + caller__file__ = inspect.stack()[1][1] + tests_dir = os.path.abspath(os.path.dirname(caller__file__)) + + while not tests_dir.endswith("tests"): + tests_dir = os.path.dirname(tests_dir) + + if append_path: + return Path(tests_dir, append_path).as_posix() + else: + return tests_dir + + +# Taken from the following PR: +# https://github.com/huggingface/accelerate/pull/1964 +def str_to_bool(value) -> int: + """ + Converts a string representation of truth to `True` (1) or `False` (0). True values are `y`, `yes`, `t`, `true`, + `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`; + """ + value = value.lower() + if value in ("y", "yes", "t", "true", "on", "1"): + return 1 + elif value in ("n", "no", "f", "false", "off", "0"): + return 0 + else: + raise ValueError(f"invalid truth value {value}") + + +def parse_flag_from_env(key, default=False): + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. + _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = str_to_bool(value) + except ValueError: + # More values are supported, but let's keep the message simple. + raise ValueError(f"If set, {key} must be yes or no.") + return _value + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) +_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False) +_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False) + + +def floats_tensor(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.random() * scale) + + return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() + + +def slow(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + + +def nightly(test_case): + """ + Decorator marking a test that runs nightly in the diffusers CI. + + Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case) + + +def is_torch_compile(test_case): + """ + Decorator marking a test that runs compile tests in the diffusers CI. + + Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case) + + +def require_torch(test_case): + """ + Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. + """ + return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) + + +def require_torch_2(test_case): + """ + Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed. + """ + return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")( + test_case + ) + + +def require_torch_gpu(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( + test_case + ) + + +# These decorators are for accelerator-specific behaviours that are not GPU-specific +def require_torch_accelerator(test_case): + """Decorator marking a test that requires an accelerator backend and PyTorch.""" + return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")( + test_case + ) + + +def require_torch_multi_gpu(test_case): + """ + Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without + multiple GPUs. To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests + -k "multi_gpu" + """ + if not is_torch_available(): + return unittest.skip("test requires PyTorch")(test_case) + + import torch + + return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case) + + +def require_torch_accelerator_with_fp16(test_case): + """Decorator marking a test that requires an accelerator with support for the FP16 data type.""" + return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")( + test_case + ) + + +def require_torch_accelerator_with_fp64(test_case): + """Decorator marking a test that requires an accelerator with support for the FP64 data type.""" + return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")( + test_case + ) + + +def require_torch_accelerator_with_training(test_case): + """Decorator marking a test that requires an accelerator with support for training.""" + return unittest.skipUnless( + is_torch_available() and backend_supports_training(torch_device), + "test requires accelerator with training support", + )(test_case) + + +def skip_mps(test_case): + """Decorator marking a test to skip if torch_device is 'mps'""" + return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case) + + +def require_flax(test_case): + """ + Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed + """ + return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case) + + +def require_compel(test_case): + """ + Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when + the library is not installed. + """ + return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case) + + +def require_onnxruntime(test_case): + """ + Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed. + """ + return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case) + + +def require_note_seq(test_case): + """ + Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed. + """ + return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case) + + +def require_torchsde(test_case): + """ + Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed. + """ + return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case) + + +def require_peft_backend(test_case): + """ + Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and + transformers. + """ + return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case) + + +def require_timm(test_case): + """ + Decorator marking a test that requires timm. These tests are skipped when timm isn't installed. + """ + return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case) + + +def require_peft_version_greater(peft_version): + """ + Decorator marking a test that requires PEFT backend with a specific version, this would require some specific + versions of PEFT and transformers. + """ + + def decorator(test_case): + correct_peft_version = is_peft_available() and version.parse( + version.parse(importlib.metadata.version("peft")).base_version + ) > version.parse(peft_version) + return unittest.skipUnless( + correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}" + )(test_case) + + return decorator + + +def require_accelerate_version_greater(accelerate_version): + def decorator(test_case): + correct_accelerate_version = is_peft_available() and version.parse( + version.parse(importlib.metadata.version("accelerate")).base_version + ) > version.parse(accelerate_version) + return unittest.skipUnless( + correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}." + )(test_case) + + return decorator + + +def deprecate_after_peft_backend(test_case): + """ + Decorator marking a test that will be skipped after PEFT backend + """ + return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case) + + +def get_python_version(): + sys_info = sys.version_info + major, minor = sys_info.major, sys_info.minor + return major, minor + + +def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray: + if isinstance(arry, str): + if local_path is not None: + # local_path can be passed to correct images of tests + return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix() + elif arry.startswith("http://") or arry.startswith("https://"): + response = requests.get(arry) + response.raise_for_status() + arry = np.load(BytesIO(response.content)) + elif os.path.isfile(arry): + arry = np.load(arry) + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path" + ) + elif isinstance(arry, np.ndarray): + pass + else: + raise ValueError( + "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a" + " ndarray." + ) + + return arry + + +def load_pt(url: str): + response = requests.get(url) + response.raise_for_status() + arry = torch.load(BytesIO(response.content)) + return arry + + +def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: + """ + Loads `image` to a PIL Image. + + Args: + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + Returns: + `PIL.Image.Image`: + A PIL Image. + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + image = PIL.Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise ValueError( + "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image." + ) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + + +def preprocess_image(image: PIL.Image, batch_size: int): + w, h = image.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> str: + if output_gif_path is None: + output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name + + image[0].save( + output_gif_path, + save_all=True, + append_images=image[1:], + optimize=False, + duration=100, + loop=0, + ) + return output_gif_path + + +@contextmanager +def buffered_writer(raw_f): + f = io.BufferedWriter(raw_f) + yield f + f.flush() + + +def export_to_ply(mesh, output_ply_path: str = None): + """ + Write a PLY file for a mesh. + """ + if output_ply_path is None: + output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name + + coords = mesh.verts.detach().cpu().numpy() + faces = mesh.faces.cpu().numpy() + rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1) + + with buffered_writer(open(output_ply_path, "wb")) as f: + f.write(b"ply\n") + f.write(b"format binary_little_endian 1.0\n") + f.write(bytes(f"element vertex {len(coords)}\n", "ascii")) + f.write(b"property float x\n") + f.write(b"property float y\n") + f.write(b"property float z\n") + if rgb is not None: + f.write(b"property uchar red\n") + f.write(b"property uchar green\n") + f.write(b"property uchar blue\n") + if faces is not None: + f.write(bytes(f"element face {len(faces)}\n", "ascii")) + f.write(b"property list uchar int vertex_index\n") + f.write(b"end_header\n") + + if rgb is not None: + rgb = (rgb * 255.499).round().astype(int) + vertices = [ + (*coord, *rgb) + for coord, rgb in zip( + coords.tolist(), + rgb.tolist(), + ) + ] + format = struct.Struct("<3f3B") + for item in vertices: + f.write(format.pack(*item)) + else: + format = struct.Struct("<3f") + for vertex in coords.tolist(): + f.write(format.pack(*vertex)) + + if faces is not None: + format = struct.Struct(" str: + if is_opencv_available(): + import cv2 + else: + raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, c = video_frames[0].shape + video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h)) + for i in range(len(video_frames)): + img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + return output_video_path + + +def load_hf_numpy(path) -> np.ndarray: + base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main" + + if not path.startswith("http://") and not path.startswith("https://"): + path = os.path.join(base_url, urllib.parse.quote(path)) + + return load_numpy(path) + + +# --- pytest conf functions --- # + +# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once +pytest_opt_registered = {} + + +def pytest_addoption_shared(parser): + """ + This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there. + + It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest` + option. + + """ + option = "--make-reports" + if option not in pytest_opt_registered: + parser.addoption( + option, + action="store", + default=False, + help="generate report files. The value of this option is used as a prefix to report names", + ) + pytest_opt_registered[option] = 1 + + +def pytest_terminal_summary_main(tr, id): + """ + Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current + directory. The report files are prefixed with the test suite name. + + This function emulates --duration and -rA pytest arguments. + + This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined + there. + + Args: + - tr: `terminalreporter` passed from `conftest.py` + - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is + needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other. + + NB: this functions taps into a private _pytest API and while unlikely, it could break should + pytest do internal changes - also it calls default internal methods of terminalreporter which + can be hijacked by various `pytest-` plugins and interfere. + + """ + from _pytest.config import create_terminal_writer + + if not len(id): + id = "tests" + + config = tr.config + orig_writer = config.get_terminal_writer() + orig_tbstyle = config.option.tbstyle + orig_reportchars = tr.reportchars + + dir = "reports" + Path(dir).mkdir(parents=True, exist_ok=True) + report_files = { + k: f"{dir}/{id}_{k}.txt" + for k in [ + "durations", + "errors", + "failures_long", + "failures_short", + "failures_line", + "passes", + "stats", + "summary_short", + "warnings", + ] + } + + # custom durations report + # note: there is no need to call pytest --durations=XX to get this separate report + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66 + dlist = [] + for replist in tr.stats.values(): + for rep in replist: + if hasattr(rep, "duration"): + dlist.append(rep) + if dlist: + dlist.sort(key=lambda x: x.duration, reverse=True) + with open(report_files["durations"], "w") as f: + durations_min = 0.05 # sec + f.write("slowest durations\n") + for i, rep in enumerate(dlist): + if rep.duration < durations_min: + f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted") + break + f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n") + + def summary_failures_short(tr): + # expecting that the reports were --tb=long (default) so we chop them off here to the last frame + reports = tr.getreports("failed") + if not reports: + return + tr.write_sep("=", "FAILURES SHORT STACK") + for rep in reports: + msg = tr._getfailureheadline(rep) + tr.write_sep("_", msg, red=True, bold=True) + # chop off the optional leading extra frames, leaving only the last one + longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S) + tr._tw.line(longrepr) + # note: not printing out any rep.sections to keep the report short + + # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814 + # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g. + # pytest-instafail does that) + + # report failures with line/short/long styles + config.option.tbstyle = "auto" # full tb + with open(report_files["failures_long"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + # config.option.tbstyle = "short" # short tb + with open(report_files["failures_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + summary_failures_short(tr) + + config.option.tbstyle = "line" # one line per error + with open(report_files["failures_line"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + with open(report_files["errors"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_errors() + + with open(report_files["warnings"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_warnings() # normal warnings + tr.summary_warnings() # final warnings + + tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary()) + with open(report_files["passes"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_passes() + + with open(report_files["summary_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.short_test_summary() + + with open(report_files["stats"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_stats() + + # restore: + tr._tw = orig_writer + tr.reportchars = orig_reportchars + config.option.tbstyle = orig_tbstyle + + +# Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905 +def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None): + """ + To decorate flaky tests. They will be retried on failures. + + Args: + max_attempts (`int`, *optional*, defaults to 5): + The maximum number of attempts to retry the flaky test. + wait_before_retry (`float`, *optional*): + If provided, will wait that number of seconds before retrying the test. + description (`str`, *optional*): + A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors, + etc.) + """ + + def decorator(test_func_ref): + @functools.wraps(test_func_ref) + def wrapper(*args, **kwargs): + retry_count = 1 + + while retry_count < max_attempts: + try: + return test_func_ref(*args, **kwargs) + + except Exception as err: + print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr) + if wait_before_retry is not None: + time.sleep(wait_before_retry) + retry_count += 1 + + return test_func_ref(*args, **kwargs) + + return wrapper + + return decorator + + +# Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787 +def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): + """ + To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. + + Args: + test_case (`unittest.TestCase`): + The test that will run `target_func`. + target_func (`Callable`): + The function implementing the actual testing logic. + inputs (`dict`, *optional*, defaults to `None`): + The inputs that will be passed to `target_func` through an (input) queue. + timeout (`int`, *optional*, defaults to `None`): + The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env. + variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`. + """ + if timeout is None: + timeout = int(os.environ.get("PYTEST_TIMEOUT", 600)) + + start_methohd = "spawn" + ctx = multiprocessing.get_context(start_methohd) + + input_queue = ctx.Queue(1) + output_queue = ctx.JoinableQueue(1) + + # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle. + input_queue.put(inputs, timeout=timeout) + + process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout)) + process.start() + # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents + # the test to exit properly. + try: + results = output_queue.get(timeout=timeout) + output_queue.task_done() + except Exception as e: + process.terminate() + test_case.fail(e) + process.join(timeout=timeout) + + if results["error"] is not None: + test_case.fail(f'{results["error"]}') + + +class CaptureLogger: + """ + Args: + Context manager to capture `logging` streams + logger: 'logging` logger object + Returns: + The captured output is available via `self.out` + Example: + ```python + >>> from diffusers import logging + >>> from diffusers.testing_utils import CaptureLogger + + >>> msg = "Testing 1, 2, 3" + >>> logging.set_verbosity_info() + >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py") + >>> with CaptureLogger(logger) as cl: + ... logger.info(msg) + >>> assert cl.out, msg + "\n" + ``` + """ + + def __init__(self, logger): + self.logger = logger + self.io = StringIO() + self.sh = logging.StreamHandler(self.io) + self.out = "" + + def __enter__(self): + self.logger.addHandler(self.sh) + return self + + def __exit__(self, *exc): + self.logger.removeHandler(self.sh) + self.out = self.io.getvalue() + + def __repr__(self): + return f"captured: {self.out}\n" + + +def enable_full_determinism(): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + """ + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cuda.matmul.allow_tf32 = False + + +def disable_full_determinism(): + os.environ["CUDA_LAUNCH_BLOCKING"] = "0" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = "" + torch.use_deterministic_algorithms(False) + + +# Utils for custom and alternative accelerator devices +def _is_torch_fp16_available(device): + if not is_torch_available(): + return False + + import torch + + device = torch.device(device) + + try: + x = torch.zeros((2, 2), dtype=torch.float16).to(device) + _ = torch.mul(x, x) + return True + + except Exception as e: + if device.type == "cuda": + raise ValueError( + f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}" + ) + + return False + + +def _is_torch_fp64_available(device): + if not is_torch_available(): + return False + + import torch + + device = torch.device(device) + + try: + x = torch.zeros((2, 2), dtype=torch.float64).to(device) + _ = torch.mul(x, x) + return True + + except Exception as e: + if device.type == "cuda": + raise ValueError( + f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}" + ) + + return False + + +# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch +if is_torch_available(): + # Behaviour flags + BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True} + + # Function definitions + BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None} + BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0} + BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed} + + +# This dispatches a defined function according to the accelerator from the function definitions. +def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs): + if device not in dispatch_table: + return dispatch_table["default"](*args, **kwargs) + + fn = dispatch_table[device] + + # Some device agnostic functions return values. Need to guard against 'None' instead at + # user level + if fn is None: + return None + + return fn(*args, **kwargs) + + +# These are callables which automatically dispatch the function specific to the accelerator +def backend_manual_seed(device: str, seed: int): + return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed) + + +def backend_empty_cache(device: str): + return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE) + + +def backend_device_count(device: str): + return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT) + + +# These are callables which return boolean behaviour flags and can be used to specify some +# device agnostic alternative where the feature is unsupported. +def backend_supports_training(device: str): + if not is_torch_available(): + return False + + if device not in BACKEND_SUPPORTS_TRAINING: + device = "default" + + return BACKEND_SUPPORTS_TRAINING[device] + + +# Guard for when Torch is not available +if is_torch_available(): + # Update device function dict mapping + def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str): + try: + # Try to import the function directly + spec_fn = getattr(device_spec_module, attribute_name) + device_fn_dict[torch_device] = spec_fn + except AttributeError as e: + # If the function doesn't exist, and there is no default, throw an error + if "default" not in device_fn_dict: + raise AttributeError( + f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found." + ) from e + + if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ: + device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"] + if not Path(device_spec_path).is_file(): + raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}") + + try: + import_name = device_spec_path[: device_spec_path.index(".py")] + except ValueError as e: + raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e + + device_spec_module = importlib.import_module(import_name) + + try: + device_name = device_spec_module.DEVICE_NAME + except AttributeError: + raise AttributeError("Device spec file did not contain `DEVICE_NAME`") + + if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name: + msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n" + msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name." + raise ValueError(msg) + + torch_device = device_name + + # Add one entry here for each `BACKEND_*` dictionary. + update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN") + update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN") + update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN") + update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING") diff --git a/diffusers/src/diffusers/utils/torch_utils.py b/diffusers/src/diffusers/utils/torch_utils.py new file mode 100755 index 0000000..0cf75b4 --- /dev/null +++ b/diffusers/src/diffusers/utils/torch_utils.py @@ -0,0 +1,148 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PyTorch utilities: Utilities related to PyTorch +""" + +from typing import List, Optional, Tuple, Union + +from . import logging +from .import_utils import is_torch_available, is_torch_version + + +if is_torch_available(): + import torch + from torch.fft import fftn, fftshift, ifftn, ifftshift + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +try: + from torch._dynamo import allow_in_graph as maybe_allow_in_graph +except (ImportError, ModuleNotFoundError): + + def maybe_allow_in_graph(cls): + return cls + + +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + device: Optional["torch.device"] = None, + dtype: Optional["torch.dtype"] = None, + layout: Optional["torch.layout"] = None, +): + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + logger.info( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents + + +def is_compiled_module(module) -> bool: + """Check whether the module was compiled with torch.compile()""" + if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): + return False + return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) + + +def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": + """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). + + This version of the method comes from here: + https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706 + """ + x = x_in + B, C, H, W = x.shape + + # Non-power of 2 images must be float32 + if (W & (W - 1)) != 0 or (H & (H - 1)) != 0: + x = x.to(dtype=torch.float32) + + # FFT + x_freq = fftn(x, dim=(-2, -1)) + x_freq = fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W), device=x.device) + + crow, ccol = H // 2, W // 2 + mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = ifftshift(x_freq, dim=(-2, -1)) + x_filtered = ifftn(x_freq, dim=(-2, -1)).real + + return x_filtered.to(dtype=x_in.dtype) + + +def apply_freeu( + resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs +) -> Tuple["torch.Tensor", "torch.Tensor"]: + """Applies the FreeU mechanism as introduced in https: + //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU. + + Args: + resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied. + hidden_states (`torch.Tensor`): Inputs to the underlying block. + res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block. + s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. + s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if resolution_idx == 0: + num_half_channels = hidden_states.shape[1] // 2 + hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"] + res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"]) + if resolution_idx == 1: + num_half_channels = hidden_states.shape[1] // 2 + hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"] + res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"]) + + return hidden_states, res_hidden_states diff --git a/diffusers/src/diffusers/utils/versions.py b/diffusers/src/diffusers/utils/versions.py new file mode 100755 index 0000000..945a397 --- /dev/null +++ b/diffusers/src/diffusers/utils/versions.py @@ -0,0 +1,117 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utilities for working with package versions +""" + +import importlib.metadata +import operator +import re +import sys +from typing import Optional + +from packaging import version + + +ops = { + "<": operator.lt, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, + ">=": operator.ge, + ">": operator.gt, +} + + +def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint): + if got_ver is None or want_ver is None: + raise ValueError( + f"Unable to compare versions for {requirement}: need={want_ver} found={got_ver}. This is unusual. Consider" + f" reinstalling {pkg}." + ) + if not ops[op](version.parse(got_ver), version.parse(want_ver)): + raise ImportError( + f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}" + ) + + +def require_version(requirement: str, hint: Optional[str] = None) -> None: + """ + Perform a runtime check of the dependency versions, using the exact same syntax used by pip. + + The installed module version comes from the *site-packages* dir via *importlib.metadata*. + + Args: + requirement (`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy" + hint (`str`, *optional*): what suggestion to print in case of requirements not being met + + Example: + + ```python + require_version("pandas>1.1.2") + require_version("numpy>1.18.5", "this is important to have for whatever reason") + ```""" + + hint = f"\n{hint}" if hint is not None else "" + + # non-versioned check + if re.match(r"^[\w_\-\d]+$", requirement): + pkg, op, want_ver = requirement, None, None + else: + match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement) + if not match: + raise ValueError( + "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but" + f" got {requirement}" + ) + pkg, want_full = match[0] + want_range = want_full.split(",") # there could be multiple requirements + wanted = {} + for w in want_range: + match = re.findall(r"^([\s!=<>]{1,2})(.+)", w) + if not match: + raise ValueError( + "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23," + f" but got {requirement}" + ) + op, want_ver = match[0] + wanted[op] = want_ver + if op not in ops: + raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}") + + # special case + if pkg == "python": + got_ver = ".".join([str(x) for x in sys.version_info[:3]]) + for op, want_ver in wanted.items(): + _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) + return + + # check if any version is installed + try: + got_ver = importlib.metadata.version(pkg) + except importlib.metadata.PackageNotFoundError: + raise importlib.metadata.PackageNotFoundError( + f"The '{requirement}' distribution was not found and is required by this application. {hint}" + ) + + # check that the right version is installed if version number or a range was provided + if want_ver is not None: + for op, want_ver in wanted.items(): + _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) + + +def require_version_core(requirement): + """require_version wrapper which emits a core-specific hint on failure""" + hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git main" + return require_version(requirement, hint) diff --git a/diffusers/src/diffusers/video_processor.py b/diffusers/src/diffusers/video_processor.py new file mode 100755 index 0000000..9e2727b --- /dev/null +++ b/diffusers/src/diffusers/video_processor.py @@ -0,0 +1,113 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch + +from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist + + +class VideoProcessor(VaeImageProcessor): + r"""Simple video processor.""" + + def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: + r""" + Preprocesses input video(s). + + Args: + video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`): + The input video. It can be one of the following: + * List of the PIL images. + * List of list of PIL images. + * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). + * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, + width)`). + * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). + * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, + num_channels)`. + * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, + width)`. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to + get default height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get + the default width. + """ + if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", + FutureWarning, + ) + video = np.concatenate(video, axis=0) + if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", + FutureWarning, + ) + video = torch.cat(video, axis=0) + + # ensure the input is a list of videos: + # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) + # - if it is is a single video, it is convereted to a list of one video. + if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: + video = list(video) + elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): + video = [video] + elif isinstance(video, list) and is_valid_image_imagelist(video[0]): + video = video + else: + raise ValueError( + "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" + ) + + video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + + # move the number of channels before the number of frames. + video = video.permute(0, 2, 1, 3, 4) + + return video + + def postprocess_video( + self, video: torch.Tensor, output_type: str = "np" + ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: + r""" + Converts a video tensor to a list of frames for export. + + Args: + video (`torch.Tensor`): The video as a tensor. + output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ + batch_size = video.shape[0] + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = self.postprocess(batch_vid, output_type) + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + + return outputs diff --git a/diffusers/utils/check_config_docstrings.py b/diffusers/utils/check_config_docstrings.py new file mode 100755 index 0000000..626a9a4 --- /dev/null +++ b/diffusers/utils/check_config_docstrings.py @@ -0,0 +1,84 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +import re + + +# All paths are set with the intent you should run this script from the root of the repo with the command +# python utils/check_config_docstrings.py +PATH_TO_TRANSFORMERS = "src/transformers" + + +# This is to make sure the transformers module imported is the one in the repo. +spec = importlib.util.spec_from_file_location( + "transformers", + os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), + submodule_search_locations=[PATH_TO_TRANSFORMERS], +) +transformers = spec.loader.load_module() + +CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING + +# Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`. +# For example, `[bert-base-uncased](https://huggingface.co/bert-base-uncased)` +_re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)") + + +CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = { + "CLIPConfigMixin", + "DecisionTransformerConfigMixin", + "EncoderDecoderConfigMixin", + "RagConfigMixin", + "SpeechEncoderDecoderConfigMixin", + "VisionEncoderDecoderConfigMixin", + "VisionTextDualEncoderConfigMixin", +} + + +def check_config_docstrings_have_checkpoints(): + configs_without_checkpoint = [] + + for config_class in list(CONFIG_MAPPING.values()): + checkpoint_found = False + + # source code of `config_class` + config_source = inspect.getsource(config_class) + checkpoints = _re_checkpoint.findall(config_source) + + for checkpoint in checkpoints: + # Each `checkpoint` is a tuple of a checkpoint name and a checkpoint link. + # For example, `('bert-base-uncased', 'https://huggingface.co/bert-base-uncased')` + ckpt_name, ckpt_link = checkpoint + + # verify the checkpoint name corresponds to the checkpoint link + ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}" + if ckpt_link == ckpt_link_from_name: + checkpoint_found = True + break + + name = config_class.__name__ + if not checkpoint_found and name not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK: + configs_without_checkpoint.append(name) + + if len(configs_without_checkpoint) > 0: + message = "\n".join(sorted(configs_without_checkpoint)) + raise ValueError(f"The following configurations don't contain any valid checkpoint:\n{message}") + + +if __name__ == "__main__": + check_config_docstrings_have_checkpoints() diff --git a/diffusers/utils/check_copies.py b/diffusers/utils/check_copies.py new file mode 100755 index 0000000..20449e7 --- /dev/null +++ b/diffusers/utils/check_copies.py @@ -0,0 +1,222 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import glob +import os +import re +import subprocess + + +# All paths are set with the intent you should run this script from the root of the repo with the command +# python utils/check_copies.py +DIFFUSERS_PATH = "src/diffusers" +REPO_PATH = "." + + +def _should_continue(line, indent): + return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None + + +def find_code_in_diffusers(object_name): + """Find and return the code source code of `object_name`.""" + parts = object_name.split(".") + i = 0 + + # First let's find the module where our object lives. + module = parts[i] + while i < len(parts) and not os.path.isfile(os.path.join(DIFFUSERS_PATH, f"{module}.py")): + i += 1 + if i < len(parts): + module = os.path.join(module, parts[i]) + if i >= len(parts): + raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.") + + with open( + os.path.join(DIFFUSERS_PATH, f"{module}.py"), + "r", + encoding="utf-8", + newline="\n", + ) as f: + lines = f.readlines() + + # Now let's find the class / func in the code! + indent = "" + line_index = 0 + for name in parts[i + 1 :]: + while ( + line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None + ): + line_index += 1 + indent += " " + line_index += 1 + + if line_index >= len(lines): + raise ValueError(f" {object_name} does not match any function or class in {module}.") + + # We found the beginning of the class / func, now let's find the end (when the indent diminishes). + start_index = line_index + while line_index < len(lines) and _should_continue(lines[line_index], indent): + line_index += 1 + # Clean up empty lines at the end (if any). + while len(lines[line_index - 1]) <= 1: + line_index -= 1 + + code_lines = lines[start_index:line_index] + return "".join(code_lines) + + +_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)") +_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)") +_re_fill_pattern = re.compile(r"]*>") + + +def get_indent(code): + lines = code.split("\n") + idx = 0 + while idx < len(lines) and len(lines[idx]) == 0: + idx += 1 + if idx < len(lines): + return re.search(r"^(\s*)\S", lines[idx]).groups()[0] + return "" + + +def run_ruff(code): + command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"] + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE) + stdout, _ = process.communicate(input=code.encode()) + return stdout.decode() + + +def stylify(code: str) -> str: + """ + Applies the ruff part of our `make style` command to some code. This formats the code using `ruff format`. + As `ruff` does not provide a python api this cannot be done on the fly. + + Args: + code (`str`): The code to format. + + Returns: + `str`: The formatted code. + """ + has_indent = len(get_indent(code)) > 0 + if has_indent: + code = f"class Bla:\n{code}" + formatted_code = run_ruff(code) + return formatted_code[len("class Bla:\n") :] if has_indent else formatted_code + + +def is_copy_consistent(filename, overwrite=False): + """ + Check if the code commented as a copy in `filename` matches the original. + Return the differences or overwrites the content depending on `overwrite`. + """ + with open(filename, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + diffs = [] + line_index = 0 + # Not a for loop cause `lines` is going to change (if `overwrite=True`). + while line_index < len(lines): + search = _re_copy_warning.search(lines[line_index]) + if search is None: + line_index += 1 + continue + + # There is some copied code here, let's retrieve the original. + indent, object_name, replace_pattern = search.groups() + theoretical_code = find_code_in_diffusers(object_name) + theoretical_indent = get_indent(theoretical_code) + + start_index = line_index + 1 if indent == theoretical_indent else line_index + 2 + indent = theoretical_indent + line_index = start_index + + # Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment. + should_continue = True + while line_index < len(lines) and should_continue: + line_index += 1 + if line_index >= len(lines): + break + line = lines[line_index] + should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None + # Clean up empty lines at the end (if any). + while len(lines[line_index - 1]) <= 1: + line_index -= 1 + + observed_code_lines = lines[start_index:line_index] + observed_code = "".join(observed_code_lines) + + # Remove any nested `Copied from` comments to avoid circular copies + theoretical_code = [line for line in theoretical_code.split("\n") if _re_copy_warning.search(line) is None] + theoretical_code = "\n".join(theoretical_code) + + # Before comparing, use the `replace_pattern` on the original code. + if len(replace_pattern) > 0: + patterns = replace_pattern.replace("with", "").split(",") + patterns = [_re_replace_pattern.search(p) for p in patterns] + for pattern in patterns: + if pattern is None: + continue + obj1, obj2, option = pattern.groups() + theoretical_code = re.sub(obj1, obj2, theoretical_code) + if option.strip() == "all-casing": + theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code) + theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code) + + # stylify after replacement. To be able to do that, we need the header (class or function definition) + # from the previous line + theoretical_code = stylify(lines[start_index - 1] + theoretical_code) + theoretical_code = theoretical_code[len(lines[start_index - 1]) :] + + # Test for a diff and act accordingly. + if observed_code != theoretical_code: + diffs.append([object_name, start_index]) + if overwrite: + lines = lines[:start_index] + [theoretical_code] + lines[line_index:] + line_index = start_index + 1 + + if overwrite and len(diffs) > 0: + # Warn the user a file has been modified. + print(f"Detected changes, rewriting {filename}.") + with open(filename, "w", encoding="utf-8", newline="\n") as f: + f.writelines(lines) + return diffs + + +def check_copies(overwrite: bool = False): + all_files = glob.glob(os.path.join(DIFFUSERS_PATH, "**/*.py"), recursive=True) + diffs = [] + for filename in all_files: + new_diffs = is_copy_consistent(filename, overwrite) + diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs] + if not overwrite and len(diffs) > 0: + diff = "\n".join(diffs) + raise Exception( + "Found the following copy inconsistencies:\n" + + diff + + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--fix_and_overwrite", + action="store_true", + help="Whether to fix inconsistencies.", + ) + args = parser.parse_args() + + check_copies(args.fix_and_overwrite) diff --git a/diffusers/utils/check_doc_toc.py b/diffusers/utils/check_doc_toc.py new file mode 100755 index 0000000..35ded93 --- /dev/null +++ b/diffusers/utils/check_doc_toc.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from collections import defaultdict + +import yaml + + +PATH_TO_TOC = "docs/source/en/_toctree.yml" + + +def clean_doc_toc(doc_list): + """ + Cleans the table of content of the model documentation by removing duplicates and sorting models alphabetically. + """ + counts = defaultdict(int) + overview_doc = [] + new_doc_list = [] + for doc in doc_list: + if "local" in doc: + counts[doc["local"]] += 1 + + if doc["title"].lower() == "overview": + overview_doc.append({"local": doc["local"], "title": doc["title"]}) + else: + new_doc_list.append(doc) + + doc_list = new_doc_list + duplicates = [key for key, value in counts.items() if value > 1] + + new_doc = [] + for duplicate_key in duplicates: + titles = list({doc["title"] for doc in doc_list if doc["local"] == duplicate_key}) + if len(titles) > 1: + raise ValueError( + f"{duplicate_key} is present several times in the documentation table of content at " + "`docs/source/en/_toctree.yml` with different *Title* values. Choose one of those and remove the " + "others." + ) + # Only add this once + new_doc.append({"local": duplicate_key, "title": titles[0]}) + + # Add none duplicate-keys + new_doc.extend([doc for doc in doc_list if "local" not in counts or counts[doc["local"]] == 1]) + new_doc = sorted(new_doc, key=lambda s: s["title"].lower()) + + # "overview" gets special treatment and is always first + if len(overview_doc) > 1: + raise ValueError("{doc_list} has two 'overview' docs which is not allowed.") + + overview_doc.extend(new_doc) + + # Sort + return overview_doc + + +def check_scheduler_doc(overwrite=False): + with open(PATH_TO_TOC, encoding="utf-8") as f: + content = yaml.safe_load(f.read()) + + # Get to the API doc + api_idx = 0 + while content[api_idx]["title"] != "API": + api_idx += 1 + api_doc = content[api_idx]["sections"] + + # Then to the model doc + scheduler_idx = 0 + while api_doc[scheduler_idx]["title"] != "Schedulers": + scheduler_idx += 1 + + scheduler_doc = api_doc[scheduler_idx]["sections"] + new_scheduler_doc = clean_doc_toc(scheduler_doc) + + diff = False + if new_scheduler_doc != scheduler_doc: + diff = True + if overwrite: + api_doc[scheduler_idx]["sections"] = new_scheduler_doc + + if diff: + if overwrite: + content[api_idx]["sections"] = api_doc + with open(PATH_TO_TOC, "w", encoding="utf-8") as f: + f.write(yaml.dump(content, allow_unicode=True)) + else: + raise ValueError( + "The model doc part of the table of content is not properly sorted, run `make style` to fix this." + ) + + +def check_pipeline_doc(overwrite=False): + with open(PATH_TO_TOC, encoding="utf-8") as f: + content = yaml.safe_load(f.read()) + + # Get to the API doc + api_idx = 0 + while content[api_idx]["title"] != "API": + api_idx += 1 + api_doc = content[api_idx]["sections"] + + # Then to the model doc + pipeline_idx = 0 + while api_doc[pipeline_idx]["title"] != "Pipelines": + pipeline_idx += 1 + + diff = False + pipeline_docs = api_doc[pipeline_idx]["sections"] + new_pipeline_docs = [] + + # sort sub pipeline docs + for pipeline_doc in pipeline_docs: + if "section" in pipeline_doc: + sub_pipeline_doc = pipeline_doc["section"] + new_sub_pipeline_doc = clean_doc_toc(sub_pipeline_doc) + if overwrite: + pipeline_doc["section"] = new_sub_pipeline_doc + new_pipeline_docs.append(pipeline_doc) + + # sort overall pipeline doc + new_pipeline_docs = clean_doc_toc(new_pipeline_docs) + + if new_pipeline_docs != pipeline_docs: + diff = True + if overwrite: + api_doc[pipeline_idx]["sections"] = new_pipeline_docs + + if diff: + if overwrite: + content[api_idx]["sections"] = api_doc + with open(PATH_TO_TOC, "w", encoding="utf-8") as f: + f.write(yaml.dump(content, allow_unicode=True)) + else: + raise ValueError( + "The model doc part of the table of content is not properly sorted, run `make style` to fix this." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") + args = parser.parse_args() + + check_scheduler_doc(args.fix_and_overwrite) + check_pipeline_doc(args.fix_and_overwrite) diff --git a/diffusers/utils/check_dummies.py b/diffusers/utils/check_dummies.py new file mode 100755 index 0000000..af99eeb --- /dev/null +++ b/diffusers/utils/check_dummies.py @@ -0,0 +1,175 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import re + + +# All paths are set with the intent you should run this script from the root of the repo with the command +# python utils/check_dummies.py +PATH_TO_DIFFUSERS = "src/diffusers" + +# Matches is_xxx_available() +_re_backend = re.compile(r"is\_([a-z_]*)_available\(\)") +# Matches from xxx import bla +_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") + + +DUMMY_CONSTANT = """ +{0} = None +""" + +DUMMY_CLASS = """ +class {0}(metaclass=DummyObject): + _backends = {1} + + def __init__(self, *args, **kwargs): + requires_backends(self, {1}) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, {1}) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, {1}) +""" + + +DUMMY_FUNCTION = """ +def {0}(*args, **kwargs): + requires_backends({0}, {1}) +""" + + +def find_backend(line): + """Find one (or multiple) backend in a code line of the init.""" + backends = _re_backend.findall(line) + if len(backends) == 0: + return None + + return "_and_".join(backends) + + +def read_init(): + """Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects.""" + with open(os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + + # Get to the point we do the actual imports for type checking + line_index = 0 + while not lines[line_index].startswith("if TYPE_CHECKING"): + line_index += 1 + + backend_specific_objects = {} + # Go through the end of the file + while line_index < len(lines): + # If the line contains is_backend_available, we grab all objects associated with the `else` block + backend = find_backend(lines[line_index]) + if backend is not None: + while not lines[line_index].startswith(" else:"): + line_index += 1 + line_index += 1 + objects = [] + # Until we unindent, add backend objects to the list + while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8): + line = lines[line_index] + single_line_import_search = _re_single_line_import.search(line) + if single_line_import_search is not None: + objects.extend(single_line_import_search.groups()[0].split(", ")) + elif line.startswith(" " * 12): + objects.append(line[12:-2]) + line_index += 1 + + if len(objects) > 0: + backend_specific_objects[backend] = objects + else: + line_index += 1 + + return backend_specific_objects + + +def create_dummy_object(name, backend_name): + """Create the code for the dummy object corresponding to `name`.""" + if name.isupper(): + return DUMMY_CONSTANT.format(name) + elif name.islower(): + return DUMMY_FUNCTION.format(name, backend_name) + else: + return DUMMY_CLASS.format(name, backend_name) + + +def create_dummy_files(backend_specific_objects=None): + """Create the content of the dummy files.""" + if backend_specific_objects is None: + backend_specific_objects = read_init() + # For special correspondence backend to module name as used in the function requires_modulename + dummy_files = {} + + for backend, objects in backend_specific_objects.items(): + backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]" + dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" + dummy_file += "from ..utils import DummyObject, requires_backends\n\n" + dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects]) + dummy_files[backend] = dummy_file + + return dummy_files + + +def check_dummies(overwrite=False): + """Check if the dummy files are up to date and maybe `overwrite` with the right content.""" + dummy_files = create_dummy_files() + # For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py + short_names = {"torch": "pt"} + + # Locate actual dummy modules and read their content. + path = os.path.join(PATH_TO_DIFFUSERS, "utils") + dummy_file_paths = { + backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py") + for backend in dummy_files.keys() + } + + actual_dummies = {} + for backend, file_path in dummy_file_paths.items(): + if os.path.isfile(file_path): + with open(file_path, "r", encoding="utf-8", newline="\n") as f: + actual_dummies[backend] = f.read() + else: + actual_dummies[backend] = "" + + for backend in dummy_files.keys(): + if dummy_files[backend] != actual_dummies[backend]: + if overwrite: + print( + f"Updating diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main " + "__init__ has new objects." + ) + with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f: + f.write(dummy_files[backend]) + else: + raise ValueError( + "The main __init__ has objects that are not present in " + f"diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` " + "to fix this." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") + args = parser.parse_args() + + check_dummies(args.fix_and_overwrite) diff --git a/diffusers/utils/check_inits.py b/diffusers/utils/check_inits.py new file mode 100755 index 0000000..2c51404 --- /dev/null +++ b/diffusers/utils/check_inits.py @@ -0,0 +1,299 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import importlib.util +import os +import re +from pathlib import Path + + +PATH_TO_TRANSFORMERS = "src/transformers" + + +# Matches is_xxx_available() +_re_backend = re.compile(r"is\_([a-z_]*)_available()") +# Catches a one-line _import_struct = {xxx} +_re_one_line_import_struct = re.compile(r"^_import_structure\s+=\s+\{([^\}]+)\}") +# Catches a line with a key-values pattern: "bla": ["foo", "bar"] +_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]') +# Catches a line if not is_foo_available +_re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)") +# Catches a line _import_struct["bla"].append("foo") +_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)') +# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"] +_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]") +# Catches a line with an object between quotes and a comma: "MyModel", +_re_quote_object = re.compile(r'^\s+"([^"]+)",') +# Catches a line with objects between brackets only: ["foo", "bar"], +_re_between_brackets = re.compile(r"^\s+\[([^\]]+)\]") +# Catches a line with from foo import bar, bla, boo +_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") +# Catches a line with try: +_re_try = re.compile(r"^\s*try:") +# Catches a line with else: +_re_else = re.compile(r"^\s*else:") + + +def find_backend(line): + """Find one (or multiple) backend in a code line of the init.""" + if _re_test_backend.search(line) is None: + return None + backends = [b[0] for b in _re_backend.findall(line)] + backends.sort() + return "_and_".join(backends) + + +def parse_init(init_file): + """ + Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects + defined + """ + with open(init_file, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + + line_index = 0 + while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"): + line_index += 1 + + # If this is a traditional init, just return. + if line_index >= len(lines): + return None + + # First grab the objects without a specific backend in _import_structure + objects = [] + while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None: + line = lines[line_index] + # If we have everything on a single line, let's deal with it. + if _re_one_line_import_struct.search(line): + content = _re_one_line_import_struct.search(line).groups()[0] + imports = re.findall(r"\[([^\]]+)\]", content) + for imp in imports: + objects.extend([obj[1:-1] for obj in imp.split(", ")]) + line_index += 1 + continue + single_line_import_search = _re_import_struct_key_value.search(line) + if single_line_import_search is not None: + imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0] + objects.extend(imports) + elif line.startswith(" " * 8 + '"'): + objects.append(line[9:-3]) + line_index += 1 + + import_dict_objects = {"none": objects} + # Let's continue with backend-specific objects in _import_structure + while not lines[line_index].startswith("if TYPE_CHECKING"): + # If the line is an if not is_backend_available, we grab all objects associated. + backend = find_backend(lines[line_index]) + # Check if the backend declaration is inside a try block: + if _re_try.search(lines[line_index - 1]) is None: + backend = None + + if backend is not None: + line_index += 1 + + # Scroll until we hit the else block of try-except-else + while _re_else.search(lines[line_index]) is None: + line_index += 1 + + line_index += 1 + + objects = [] + # Until we unindent, add backend objects to the list + while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4): + line = lines[line_index] + if _re_import_struct_add_one.search(line) is not None: + objects.append(_re_import_struct_add_one.search(line).groups()[0]) + elif _re_import_struct_add_many.search(line) is not None: + imports = _re_import_struct_add_many.search(line).groups()[0].split(", ") + imports = [obj[1:-1] for obj in imports if len(obj) > 0] + objects.extend(imports) + elif _re_between_brackets.search(line) is not None: + imports = _re_between_brackets.search(line).groups()[0].split(", ") + imports = [obj[1:-1] for obj in imports if len(obj) > 0] + objects.extend(imports) + elif _re_quote_object.search(line) is not None: + objects.append(_re_quote_object.search(line).groups()[0]) + elif line.startswith(" " * 8 + '"'): + objects.append(line[9:-3]) + elif line.startswith(" " * 12 + '"'): + objects.append(line[13:-3]) + line_index += 1 + + import_dict_objects[backend] = objects + else: + line_index += 1 + + # At this stage we are in the TYPE_CHECKING part, first grab the objects without a specific backend + objects = [] + while ( + line_index < len(lines) + and find_backend(lines[line_index]) is None + and not lines[line_index].startswith("else") + ): + line = lines[line_index] + single_line_import_search = _re_import.search(line) + if single_line_import_search is not None: + objects.extend(single_line_import_search.groups()[0].split(", ")) + elif line.startswith(" " * 8): + objects.append(line[8:-2]) + line_index += 1 + + type_hint_objects = {"none": objects} + # Let's continue with backend-specific objects + while line_index < len(lines): + # If the line is an if is_backend_available, we grab all objects associated. + backend = find_backend(lines[line_index]) + # Check if the backend declaration is inside a try block: + if _re_try.search(lines[line_index - 1]) is None: + backend = None + + if backend is not None: + line_index += 1 + + # Scroll until we hit the else block of try-except-else + while _re_else.search(lines[line_index]) is None: + line_index += 1 + + line_index += 1 + + objects = [] + # Until we unindent, add backend objects to the list + while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8): + line = lines[line_index] + single_line_import_search = _re_import.search(line) + if single_line_import_search is not None: + objects.extend(single_line_import_search.groups()[0].split(", ")) + elif line.startswith(" " * 12): + objects.append(line[12:-2]) + line_index += 1 + + type_hint_objects[backend] = objects + else: + line_index += 1 + + return import_dict_objects, type_hint_objects + + +def analyze_results(import_dict_objects, type_hint_objects): + """ + Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init. + """ + + def find_duplicates(seq): + return [k for k, v in collections.Counter(seq).items() if v > 1] + + if list(import_dict_objects.keys()) != list(type_hint_objects.keys()): + return ["Both sides of the init do not have the same backends!"] + + errors = [] + for key in import_dict_objects.keys(): + duplicate_imports = find_duplicates(import_dict_objects[key]) + if duplicate_imports: + errors.append(f"Duplicate _import_structure definitions for: {duplicate_imports}") + duplicate_type_hints = find_duplicates(type_hint_objects[key]) + if duplicate_type_hints: + errors.append(f"Duplicate TYPE_CHECKING objects for: {duplicate_type_hints}") + + if sorted(set(import_dict_objects[key])) != sorted(set(type_hint_objects[key])): + name = "base imports" if key == "none" else f"{key} backend" + errors.append(f"Differences for {name}:") + for a in type_hint_objects[key]: + if a not in import_dict_objects[key]: + errors.append(f" {a} in TYPE_HINT but not in _import_structure.") + for a in import_dict_objects[key]: + if a not in type_hint_objects[key]: + errors.append(f" {a} in _import_structure but not in TYPE_HINT.") + return errors + + +def check_all_inits(): + """ + Check all inits in the transformers repo and raise an error if at least one does not define the same objects in + both halves. + """ + failures = [] + for root, _, files in os.walk(PATH_TO_TRANSFORMERS): + if "__init__.py" in files: + fname = os.path.join(root, "__init__.py") + objects = parse_init(fname) + if objects is not None: + errors = analyze_results(*objects) + if len(errors) > 0: + errors[0] = f"Problem in {fname}, both halves do not define the same objects.\n{errors[0]}" + failures.append("\n".join(errors)) + if len(failures) > 0: + raise ValueError("\n\n".join(failures)) + + +def get_transformers_submodules(): + """ + Returns the list of Transformers submodules. + """ + submodules = [] + for path, directories, files in os.walk(PATH_TO_TRANSFORMERS): + for folder in directories: + # Ignore private modules + if folder.startswith("_"): + directories.remove(folder) + continue + # Ignore leftovers from branches (empty folders apart from pycache) + if len(list((Path(path) / folder).glob("*.py"))) == 0: + continue + short_path = str((Path(path) / folder).relative_to(PATH_TO_TRANSFORMERS)) + submodule = short_path.replace(os.path.sep, ".") + submodules.append(submodule) + for fname in files: + if fname == "__init__.py": + continue + short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS)) + submodule = short_path.replace(".py", "").replace(os.path.sep, ".") + if len(submodule.split(".")) == 1: + submodules.append(submodule) + return submodules + + +IGNORE_SUBMODULES = [ + "convert_pytorch_checkpoint_to_tf2", + "modeling_flax_pytorch_utils", +] + + +def check_submodules(): + # This is to make sure the transformers module imported is the one in the repo. + spec = importlib.util.spec_from_file_location( + "transformers", + os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), + submodule_search_locations=[PATH_TO_TRANSFORMERS], + ) + transformers = spec.loader.load_module() + + module_not_registered = [ + module + for module in get_transformers_submodules() + if module not in IGNORE_SUBMODULES and module not in transformers._import_structure.keys() + ] + if len(module_not_registered) > 0: + list_of_modules = "\n".join(f"- {module}" for module in module_not_registered) + raise ValueError( + "The following submodules are not properly registered in the main init of Transformers:\n" + f"{list_of_modules}\n" + "Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value." + ) + + +if __name__ == "__main__": + check_all_inits() + check_submodules() diff --git a/diffusers/utils/check_repo.py b/diffusers/utils/check_repo.py new file mode 100755 index 0000000..597893f --- /dev/null +++ b/diffusers/utils/check_repo.py @@ -0,0 +1,755 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +import re +import warnings +from collections import OrderedDict +from difflib import get_close_matches +from pathlib import Path + +from diffusers.models.auto import get_values +from diffusers.utils import ENV_VARS_TRUE_VALUES, is_flax_available, is_torch_available + + +# All paths are set with the intent you should run this script from the root of the repo with the command +# python utils/check_repo.py +PATH_TO_DIFFUSERS = "src/diffusers" +PATH_TO_TESTS = "tests" +PATH_TO_DOC = "docs/source/en" + +# Update this list with models that are supposed to be private. +PRIVATE_MODELS = [ + "DPRSpanPredictor", + "RealmBertModel", + "T5Stack", + "TFDPRSpanPredictor", +] + +# Update this list for models that are not tested with a comment explaining the reason it should not be. +# Being in this list is an exception and should **not** be the rule. +IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ + # models to ignore for not tested + "OPTDecoder", # Building part of bigger (tested) model. + "DecisionTransformerGPT2Model", # Building part of bigger (tested) model. + "SegformerDecodeHead", # Building part of bigger (tested) model. + "PLBartEncoder", # Building part of bigger (tested) model. + "PLBartDecoder", # Building part of bigger (tested) model. + "PLBartDecoderWrapper", # Building part of bigger (tested) model. + "BigBirdPegasusEncoder", # Building part of bigger (tested) model. + "BigBirdPegasusDecoder", # Building part of bigger (tested) model. + "BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model. + "DetrEncoder", # Building part of bigger (tested) model. + "DetrDecoder", # Building part of bigger (tested) model. + "DetrDecoderWrapper", # Building part of bigger (tested) model. + "M2M100Encoder", # Building part of bigger (tested) model. + "M2M100Decoder", # Building part of bigger (tested) model. + "Speech2TextEncoder", # Building part of bigger (tested) model. + "Speech2TextDecoder", # Building part of bigger (tested) model. + "LEDEncoder", # Building part of bigger (tested) model. + "LEDDecoder", # Building part of bigger (tested) model. + "BartDecoderWrapper", # Building part of bigger (tested) model. + "BartEncoder", # Building part of bigger (tested) model. + "BertLMHeadModel", # Needs to be setup as decoder. + "BlenderbotSmallEncoder", # Building part of bigger (tested) model. + "BlenderbotSmallDecoderWrapper", # Building part of bigger (tested) model. + "BlenderbotEncoder", # Building part of bigger (tested) model. + "BlenderbotDecoderWrapper", # Building part of bigger (tested) model. + "MBartEncoder", # Building part of bigger (tested) model. + "MBartDecoderWrapper", # Building part of bigger (tested) model. + "MegatronBertLMHeadModel", # Building part of bigger (tested) model. + "MegatronBertEncoder", # Building part of bigger (tested) model. + "MegatronBertDecoder", # Building part of bigger (tested) model. + "MegatronBertDecoderWrapper", # Building part of bigger (tested) model. + "PegasusEncoder", # Building part of bigger (tested) model. + "PegasusDecoderWrapper", # Building part of bigger (tested) model. + "DPREncoder", # Building part of bigger (tested) model. + "ProphetNetDecoderWrapper", # Building part of bigger (tested) model. + "RealmBertModel", # Building part of bigger (tested) model. + "RealmReader", # Not regular model. + "RealmScorer", # Not regular model. + "RealmForOpenQA", # Not regular model. + "ReformerForMaskedLM", # Needs to be setup as decoder. + "Speech2Text2DecoderWrapper", # Building part of bigger (tested) model. + "TFDPREncoder", # Building part of bigger (tested) model. + "TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFModelMixin ?) + "TFRobertaForMultipleChoice", # TODO: fix + "TrOCRDecoderWrapper", # Building part of bigger (tested) model. + "SeparableConv1D", # Building part of bigger (tested) model. + "FlaxBartForCausalLM", # Building part of bigger (tested) model. + "FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM. + "OPTDecoderWrapper", +] + +# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't +# trigger the common tests. +TEST_FILES_WITH_NO_COMMON_TESTS = [ + "models/decision_transformer/test_modeling_decision_transformer.py", + "models/camembert/test_modeling_camembert.py", + "models/mt5/test_modeling_flax_mt5.py", + "models/mbart/test_modeling_mbart.py", + "models/mt5/test_modeling_mt5.py", + "models/pegasus/test_modeling_pegasus.py", + "models/camembert/test_modeling_tf_camembert.py", + "models/mt5/test_modeling_tf_mt5.py", + "models/xlm_roberta/test_modeling_tf_xlm_roberta.py", + "models/xlm_roberta/test_modeling_flax_xlm_roberta.py", + "models/xlm_prophetnet/test_modeling_xlm_prophetnet.py", + "models/xlm_roberta/test_modeling_xlm_roberta.py", + "models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py", + "models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py", + "models/decision_transformer/test_modeling_decision_transformer.py", +] + +# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and +# should **not** be the rule. +IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ + # models to ignore for model xxx mapping + "DPTForDepthEstimation", + "DecisionTransformerGPT2Model", + "GLPNForDepthEstimation", + "ViltForQuestionAnswering", + "ViltForImagesAndTextClassification", + "ViltForImageAndTextRetrieval", + "ViltForMaskedLM", + "XGLMEncoder", + "XGLMDecoder", + "XGLMDecoderWrapper", + "PerceiverForMultimodalAutoencoding", + "PerceiverForOpticalFlow", + "SegformerDecodeHead", + "FlaxBeitForMaskedImageModeling", + "PLBartEncoder", + "PLBartDecoder", + "PLBartDecoderWrapper", + "BeitForMaskedImageModeling", + "CLIPTextModel", + "CLIPVisionModel", + "TFCLIPTextModel", + "TFCLIPVisionModel", + "FlaxCLIPTextModel", + "FlaxCLIPVisionModel", + "FlaxWav2Vec2ForCTC", + "DetrForSegmentation", + "DPRReader", + "FlaubertForQuestionAnswering", + "FlavaImageCodebook", + "FlavaTextModel", + "FlavaImageModel", + "FlavaMultimodalModel", + "GPT2DoubleHeadsModel", + "LukeForMaskedLM", + "LukeForEntityClassification", + "LukeForEntityPairClassification", + "LukeForEntitySpanClassification", + "OpenAIGPTDoubleHeadsModel", + "RagModel", + "RagSequenceForGeneration", + "RagTokenForGeneration", + "RealmEmbedder", + "RealmForOpenQA", + "RealmScorer", + "RealmReader", + "TFDPRReader", + "TFGPT2DoubleHeadsModel", + "TFOpenAIGPTDoubleHeadsModel", + "TFRagModel", + "TFRagSequenceForGeneration", + "TFRagTokenForGeneration", + "Wav2Vec2ForCTC", + "HubertForCTC", + "SEWForCTC", + "SEWDForCTC", + "XLMForQuestionAnswering", + "XLNetForQuestionAnswering", + "SeparableConv1D", + "VisualBertForRegionToPhraseAlignment", + "VisualBertForVisualReasoning", + "VisualBertForQuestionAnswering", + "VisualBertForMultipleChoice", + "TFWav2Vec2ForCTC", + "TFHubertForCTC", + "MaskFormerForInstanceSegmentation", +] + +# Update this list for models that have multiple model types for the same +# model doc +MODEL_TYPE_TO_DOC_MAPPING = OrderedDict( + [ + ("data2vec-text", "data2vec"), + ("data2vec-audio", "data2vec"), + ("data2vec-vision", "data2vec"), + ] +) + + +# This is to make sure the transformers module imported is the one in the repo. +spec = importlib.util.spec_from_file_location( + "diffusers", + os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), + submodule_search_locations=[PATH_TO_DIFFUSERS], +) +diffusers = spec.loader.load_module() + + +def check_model_list(): + """Check the model list inside the transformers library.""" + # Get the models from the directory structure of `src/diffusers/models/` + models_dir = os.path.join(PATH_TO_DIFFUSERS, "models") + _models = [] + for model in os.listdir(models_dir): + model_dir = os.path.join(models_dir, model) + if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir): + _models.append(model) + + # Get the models from the directory structure of `src/transformers/models/` + models = [model for model in dir(diffusers.models) if not model.startswith("__")] + + missing_models = sorted(set(_models).difference(models)) + if missing_models: + raise Exception( + f"The following models should be included in {models_dir}/__init__.py: {','.join(missing_models)}." + ) + + +# If some modeling modules should be ignored for all checks, they should be added in the nested list +# _ignore_modules of this function. +def get_model_modules(): + """Get the model modules inside the transformers library.""" + _ignore_modules = [ + "modeling_auto", + "modeling_encoder_decoder", + "modeling_marian", + "modeling_mmbt", + "modeling_outputs", + "modeling_retribert", + "modeling_utils", + "modeling_flax_auto", + "modeling_flax_encoder_decoder", + "modeling_flax_utils", + "modeling_speech_encoder_decoder", + "modeling_flax_speech_encoder_decoder", + "modeling_flax_vision_encoder_decoder", + "modeling_transfo_xl_utilities", + "modeling_tf_auto", + "modeling_tf_encoder_decoder", + "modeling_tf_outputs", + "modeling_tf_pytorch_utils", + "modeling_tf_utils", + "modeling_tf_transfo_xl_utilities", + "modeling_tf_vision_encoder_decoder", + "modeling_vision_encoder_decoder", + ] + modules = [] + for model in dir(diffusers.models): + # There are some magic dunder attributes in the dir, we ignore them + if not model.startswith("__"): + model_module = getattr(diffusers.models, model) + for submodule in dir(model_module): + if submodule.startswith("modeling") and submodule not in _ignore_modules: + modeling_module = getattr(model_module, submodule) + if inspect.ismodule(modeling_module): + modules.append(modeling_module) + return modules + + +def get_models(module, include_pretrained=False): + """Get the objects in module that are models.""" + models = [] + model_classes = (diffusers.ModelMixin, diffusers.TFModelMixin, diffusers.FlaxModelMixin) + for attr_name in dir(module): + if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name): + continue + attr = getattr(module, attr_name) + if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__: + models.append((attr_name, attr)) + return models + + +def is_a_private_model(model): + """Returns True if the model should not be in the main init.""" + if model in PRIVATE_MODELS: + return True + + # Wrapper, Encoder and Decoder are all privates + if model.endswith("Wrapper"): + return True + if model.endswith("Encoder"): + return True + if model.endswith("Decoder"): + return True + return False + + +def check_models_are_in_init(): + """Checks all models defined in the library are in the main init.""" + models_not_in_init = [] + dir_transformers = dir(diffusers) + for module in get_model_modules(): + models_not_in_init += [ + model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers + ] + + # Remove private models + models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)] + if len(models_not_in_init) > 0: + raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.") + + +# If some test_modeling files should be ignored when checking models are all tested, they should be added in the +# nested list _ignore_files of this function. +def get_model_test_files(): + """Get the model test files. + + The returned files should NOT contain the `tests` (i.e. `PATH_TO_TESTS` defined in this script). They will be + considered as paths relative to `tests`. A caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files. + """ + + _ignore_files = [ + "test_modeling_common", + "test_modeling_encoder_decoder", + "test_modeling_flax_encoder_decoder", + "test_modeling_flax_speech_encoder_decoder", + "test_modeling_marian", + "test_modeling_tf_common", + "test_modeling_tf_encoder_decoder", + ] + test_files = [] + # Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models` + model_test_root = os.path.join(PATH_TO_TESTS, "models") + model_test_dirs = [] + for x in os.listdir(model_test_root): + x = os.path.join(model_test_root, x) + if os.path.isdir(x): + model_test_dirs.append(x) + + for target_dir in [PATH_TO_TESTS] + model_test_dirs: + for file_or_dir in os.listdir(target_dir): + path = os.path.join(target_dir, file_or_dir) + if os.path.isfile(path): + filename = os.path.split(path)[-1] + if "test_modeling" in filename and os.path.splitext(filename)[0] not in _ignore_files: + file = os.path.join(*path.split(os.sep)[1:]) + test_files.append(file) + + return test_files + + +# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class +# for the all_model_classes variable. +def find_tested_models(test_file): + """Parse the content of test_file to detect what's in all_model_classes""" + # This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class + with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f: + content = f.read() + all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content) + # Check with one less parenthesis as well + all_models += re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content) + if len(all_models) > 0: + model_tested = [] + for entry in all_models: + for line in entry.split(","): + name = line.strip() + if len(name) > 0: + model_tested.append(name) + return model_tested + + +def check_models_are_tested(module, test_file): + """Check models defined in module are tested in test_file.""" + # XxxModelMixin are not tested + defined_models = get_models(module) + tested_models = find_tested_models(test_file) + if tested_models is None: + if test_file.replace(os.path.sep, "/") in TEST_FILES_WITH_NO_COMMON_TESTS: + return + return [ + f"{test_file} should define `all_model_classes` to apply common tests to the models it tests. " + + "If this intentional, add the test filename to `TEST_FILES_WITH_NO_COMMON_TESTS` in the file " + + "`utils/check_repo.py`." + ] + failures = [] + for model_name, _ in defined_models: + if model_name not in tested_models and model_name not in IGNORE_NON_TESTED: + failures.append( + f"{model_name} is defined in {module.__name__} but is not tested in " + + f"{os.path.join(PATH_TO_TESTS, test_file)}. Add it to the all_model_classes in that file." + + "If common tests should not applied to that model, add its name to `IGNORE_NON_TESTED`" + + "in the file `utils/check_repo.py`." + ) + return failures + + +def check_all_models_are_tested(): + """Check all models are properly tested.""" + modules = get_model_modules() + test_files = get_model_test_files() + failures = [] + for module in modules: + test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file] + if len(test_file) == 0: + failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.") + elif len(test_file) > 1: + failures.append(f"{module.__name__} has several test files: {test_file}.") + else: + test_file = test_file[0] + new_failures = check_models_are_tested(module, test_file) + if new_failures is not None: + failures += new_failures + if len(failures) > 0: + raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures)) + + +def get_all_auto_configured_models(): + """Return the list of all models in at least one auto class.""" + result = set() # To avoid duplicates we concatenate all model classes in a set. + if is_torch_available(): + for attr_name in dir(diffusers.models.auto.modeling_auto): + if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"): + result = result | set(get_values(getattr(diffusers.models.auto.modeling_auto, attr_name))) + if is_flax_available(): + for attr_name in dir(diffusers.models.auto.modeling_flax_auto): + if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"): + result = result | set(get_values(getattr(diffusers.models.auto.modeling_flax_auto, attr_name))) + return list(result) + + +def ignore_unautoclassed(model_name): + """Rules to determine if `name` should be in an auto class.""" + # Special white list + if model_name in IGNORE_NON_AUTO_CONFIGURED: + return True + # Encoder and Decoder should be ignored + if "Encoder" in model_name or "Decoder" in model_name: + return True + return False + + +def check_models_are_auto_configured(module, all_auto_models): + """Check models defined in module are each in an auto class.""" + defined_models = get_models(module) + failures = [] + for model_name, _ in defined_models: + if model_name not in all_auto_models and not ignore_unautoclassed(model_name): + failures.append( + f"{model_name} is defined in {module.__name__} but is not present in any of the auto mapping. " + "If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file " + "`utils/check_repo.py`." + ) + return failures + + +def check_all_models_are_auto_configured(): + """Check all models are each in an auto class.""" + missing_backends = [] + if not is_torch_available(): + missing_backends.append("PyTorch") + if not is_flax_available(): + missing_backends.append("Flax") + if len(missing_backends) > 0: + missing = ", ".join(missing_backends) + if os.getenv("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES: + raise Exception( + "Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the " + f"Transformers repo, the following are missing: {missing}." + ) + else: + warnings.warn( + "Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the " + f"Transformers repo, the following are missing: {missing}. While it's probably fine as long as you " + "didn't make any change in one of those backends modeling files, you should probably execute the " + "command above to be on the safe side." + ) + modules = get_model_modules() + all_auto_models = get_all_auto_configured_models() + failures = [] + for module in modules: + new_failures = check_models_are_auto_configured(module, all_auto_models) + if new_failures is not None: + failures += new_failures + if len(failures) > 0: + raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures)) + + +_re_decorator = re.compile(r"^\s*@(\S+)\s+$") + + +def check_decorator_order(filename): + """Check that in the test file `filename` the slow decorator is always last.""" + with open(filename, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + decorator_before = None + errors = [] + for i, line in enumerate(lines): + search = _re_decorator.search(line) + if search is not None: + decorator_name = search.groups()[0] + if decorator_before is not None and decorator_name.startswith("parameterized"): + errors.append(i) + decorator_before = decorator_name + elif decorator_before is not None: + decorator_before = None + return errors + + +def check_all_decorator_order(): + """Check that in all test files, the slow decorator is always last.""" + errors = [] + for fname in os.listdir(PATH_TO_TESTS): + if fname.endswith(".py"): + filename = os.path.join(PATH_TO_TESTS, fname) + new_errors = check_decorator_order(filename) + errors += [f"- {filename}, line {i}" for i in new_errors] + if len(errors) > 0: + msg = "\n".join(errors) + raise ValueError( + "The parameterized decorator (and its variants) should always be first, but this is not the case in the" + f" following files:\n{msg}" + ) + + +def find_all_documented_objects(): + """Parse the content of all doc files to detect which classes and functions it documents""" + documented_obj = [] + for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"): + with open(doc_file, "r", encoding="utf-8", newline="\n") as f: + content = f.read() + raw_doc_objs = re.findall(r"(?:autoclass|autofunction):: transformers.(\S+)\s+", content) + documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs] + for doc_file in Path(PATH_TO_DOC).glob("**/*.md"): + with open(doc_file, "r", encoding="utf-8", newline="\n") as f: + content = f.read() + raw_doc_objs = re.findall(r"\[\[autodoc\]\]\s+(\S+)\s+", content) + documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs] + return documented_obj + + +# One good reason for not being documented is to be deprecated. Put in this list deprecated objects. +DEPRECATED_OBJECTS = [ + "AutoModelWithLMHead", + "BartPretrainedModel", + "DataCollator", + "DataCollatorForSOP", + "GlueDataset", + "GlueDataTrainingArguments", + "LineByLineTextDataset", + "LineByLineWithRefDataset", + "LineByLineWithSOPTextDataset", + "PretrainedBartModel", + "PretrainedFSMTModel", + "SingleSentenceClassificationProcessor", + "SquadDataTrainingArguments", + "SquadDataset", + "SquadExample", + "SquadFeatures", + "SquadV1Processor", + "SquadV2Processor", + "TFAutoModelWithLMHead", + "TFBartPretrainedModel", + "TextDataset", + "TextDatasetForNextSentencePrediction", + "Wav2Vec2ForMaskedLM", + "Wav2Vec2Tokenizer", + "glue_compute_metrics", + "glue_convert_examples_to_features", + "glue_output_modes", + "glue_processors", + "glue_tasks_num_labels", + "squad_convert_examples_to_features", + "xnli_compute_metrics", + "xnli_output_modes", + "xnli_processors", + "xnli_tasks_num_labels", + "TFTrainer", + "TFTrainingArguments", +] + +# Exceptionally, some objects should not be documented after all rules passed. +# ONLY PUT SOMETHING IN THIS LIST AS A LAST RESORT! +UNDOCUMENTED_OBJECTS = [ + "AddedToken", # This is a tokenizers class. + "BasicTokenizer", # Internal, should never have been in the main init. + "CharacterTokenizer", # Internal, should never have been in the main init. + "DPRPretrainedReader", # Like an Encoder. + "DummyObject", # Just picked by mistake sometimes. + "MecabTokenizer", # Internal, should never have been in the main init. + "ModelCard", # Internal type. + "SqueezeBertModule", # Internal building block (should have been called SqueezeBertLayer) + "TFDPRPretrainedReader", # Like an Encoder. + "TransfoXLCorpus", # Internal type. + "WordpieceTokenizer", # Internal, should never have been in the main init. + "absl", # External module + "add_end_docstrings", # Internal, should never have been in the main init. + "add_start_docstrings", # Internal, should never have been in the main init. + "cached_path", # Internal used for downloading models. + "convert_tf_weight_name_to_pt_weight_name", # Internal used to convert model weights + "logger", # Internal logger + "logging", # External module + "requires_backends", # Internal function +] + +# This list should be empty. Objects in it should get their own doc page. +SHOULD_HAVE_THEIR_OWN_PAGE = [ + # Benchmarks + "PyTorchBenchmark", + "PyTorchBenchmarkArguments", + "TensorFlowBenchmark", + "TensorFlowBenchmarkArguments", +] + + +def ignore_undocumented(name): + """Rules to determine if `name` should be undocumented.""" + # NOT DOCUMENTED ON PURPOSE. + # Constants uppercase are not documented. + if name.isupper(): + return True + # ModelMixins / Encoders / Decoders / Layers / Embeddings / Attention are not documented. + if ( + name.endswith("ModelMixin") + or name.endswith("Decoder") + or name.endswith("Encoder") + or name.endswith("Layer") + or name.endswith("Embeddings") + or name.endswith("Attention") + ): + return True + # Submodules are not documented. + if os.path.isdir(os.path.join(PATH_TO_DIFFUSERS, name)) or os.path.isfile( + os.path.join(PATH_TO_DIFFUSERS, f"{name}.py") + ): + return True + # All load functions are not documented. + if name.startswith("load_tf") or name.startswith("load_pytorch"): + return True + # is_xxx_available functions are not documented. + if name.startswith("is_") and name.endswith("_available"): + return True + # Deprecated objects are not documented. + if name in DEPRECATED_OBJECTS or name in UNDOCUMENTED_OBJECTS: + return True + # MMBT model does not really work. + if name.startswith("MMBT"): + return True + if name in SHOULD_HAVE_THEIR_OWN_PAGE: + return True + return False + + +def check_all_objects_are_documented(): + """Check all models are properly documented.""" + documented_objs = find_all_documented_objects() + modules = diffusers._modules + objects = [c for c in dir(diffusers) if c not in modules and not c.startswith("_")] + undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)] + if len(undocumented_objs) > 0: + raise Exception( + "The following objects are in the public init so should be documented:\n - " + + "\n - ".join(undocumented_objs) + ) + check_docstrings_are_in_md() + check_model_type_doc_match() + + +def check_model_type_doc_match(): + """Check all doc pages have a corresponding model type.""" + model_doc_folder = Path(PATH_TO_DOC) / "model_doc" + model_docs = [m.stem for m in model_doc_folder.glob("*.md")] + + model_types = list(diffusers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys()) + model_types = [MODEL_TYPE_TO_DOC_MAPPING[m] if m in MODEL_TYPE_TO_DOC_MAPPING else m for m in model_types] + + errors = [] + for m in model_docs: + if m not in model_types and m != "auto": + close_matches = get_close_matches(m, model_types) + error_message = f"{m} is not a proper model identifier." + if len(close_matches) > 0: + close_matches = "/".join(close_matches) + error_message += f" Did you mean {close_matches}?" + errors.append(error_message) + + if len(errors) > 0: + raise ValueError( + "Some model doc pages do not match any existing model type:\n" + + "\n".join(errors) + + "\nYou can add any missing model type to the `MODEL_NAMES_MAPPING` constant in " + "models/auto/configuration_auto.py." + ) + + +# Re pattern to catch :obj:`xx`, :class:`xx`, :func:`xx` or :meth:`xx`. +_re_rst_special_words = re.compile(r":(?:obj|func|class|meth):`([^`]+)`") +# Re pattern to catch things between double backquotes. +_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)") +# Re pattern to catch example introduction. +_re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE) + + +def is_rst_docstring(docstring): + """ + Returns `True` if `docstring` is written in rst. + """ + if _re_rst_special_words.search(docstring) is not None: + return True + if _re_double_backquotes.search(docstring) is not None: + return True + if _re_rst_example.search(docstring) is not None: + return True + return False + + +def check_docstrings_are_in_md(): + """Check all docstrings are in md""" + files_with_rst = [] + for file in Path(PATH_TO_DIFFUSERS).glob("**/*.py"): + with open(file, "r") as f: + code = f.read() + docstrings = code.split('"""') + + for idx, docstring in enumerate(docstrings): + if idx % 2 == 0 or not is_rst_docstring(docstring): + continue + files_with_rst.append(file) + break + + if len(files_with_rst) > 0: + raise ValueError( + "The following files have docstrings written in rst:\n" + + "\n".join([f"- {f}" for f in files_with_rst]) + + "\nTo fix this run `doc-builder convert path_to_py_file` after installing `doc-builder`\n" + "(`pip install git+https://github.com/huggingface/doc-builder`)" + ) + + +def check_repo_quality(): + """Check all models are properly tested and documented.""" + print("Checking all models are included.") + check_model_list() + print("Checking all models are public.") + check_models_are_in_init() + print("Checking all models are properly tested.") + check_all_decorator_order() + check_all_models_are_tested() + print("Checking all objects are properly documented.") + check_all_objects_are_documented() + print("Checking all models are in at least one auto class.") + check_all_models_are_auto_configured() + + +if __name__ == "__main__": + check_repo_quality() diff --git a/diffusers/utils/check_table.py b/diffusers/utils/check_table.py new file mode 100755 index 0000000..80fd566 --- /dev/null +++ b/diffusers/utils/check_table.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import collections +import importlib.util +import os +import re + + +# All paths are set with the intent you should run this script from the root of the repo with the command +# python utils/check_table.py +TRANSFORMERS_PATH = "src/diffusers" +PATH_TO_DOCS = "docs/source/en" +REPO_PATH = "." + + +def _find_text_in_file(filename, start_prompt, end_prompt): + """ + Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty + lines. + """ + with open(filename, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + # Find the start prompt. + start_index = 0 + while not lines[start_index].startswith(start_prompt): + start_index += 1 + start_index += 1 + + end_index = start_index + while not lines[end_index].startswith(end_prompt): + end_index += 1 + end_index -= 1 + + while len(lines[start_index]) <= 1: + start_index += 1 + while len(lines[end_index]) <= 1: + end_index -= 1 + end_index += 1 + return "".join(lines[start_index:end_index]), start_index, end_index, lines + + +# Add here suffixes that are used to identify models, separated by | +ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration" +# Regexes that match TF/Flax/PT model names. +_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") +_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") +# Will match any TF or Flax model too so need to be in an else branch afterthe two previous regexes. +_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") + + +# This is to make sure the diffusers module imported is the one in the repo. +spec = importlib.util.spec_from_file_location( + "diffusers", + os.path.join(TRANSFORMERS_PATH, "__init__.py"), + submodule_search_locations=[TRANSFORMERS_PATH], +) +diffusers_module = spec.loader.load_module() + + +# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python +def camel_case_split(identifier): + """Split a camelcased `identifier` into words.""" + matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier) + return [m.group(0) for m in matches] + + +def _center_text(text, width): + text_length = 2 if text == "✅" or text == "❌" else len(text) + left_indent = (width - text_length) // 2 + right_indent = width - text_length - left_indent + return " " * left_indent + text + " " * right_indent + + +def get_model_table_from_auto_modules(): + """Generates an up-to-date model table from the content of the auto modules.""" + # Dictionary model names to config. + config_mapping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES + model_name_to_config = { + name: config_mapping_names[code] + for code, name in diffusers_module.MODEL_NAMES_MAPPING.items() + if code in config_mapping_names + } + model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()} + + # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax. + slow_tokenizers = collections.defaultdict(bool) + fast_tokenizers = collections.defaultdict(bool) + pt_models = collections.defaultdict(bool) + tf_models = collections.defaultdict(bool) + flax_models = collections.defaultdict(bool) + + # Let's lookup through all diffusers object (once). + for attr_name in dir(diffusers_module): + lookup_dict = None + if attr_name.endswith("Tokenizer"): + lookup_dict = slow_tokenizers + attr_name = attr_name[:-9] + elif attr_name.endswith("TokenizerFast"): + lookup_dict = fast_tokenizers + attr_name = attr_name[:-13] + elif _re_tf_models.match(attr_name) is not None: + lookup_dict = tf_models + attr_name = _re_tf_models.match(attr_name).groups()[0] + elif _re_flax_models.match(attr_name) is not None: + lookup_dict = flax_models + attr_name = _re_flax_models.match(attr_name).groups()[0] + elif _re_pt_models.match(attr_name) is not None: + lookup_dict = pt_models + attr_name = _re_pt_models.match(attr_name).groups()[0] + + if lookup_dict is not None: + while len(attr_name) > 0: + if attr_name in model_name_to_prefix.values(): + lookup_dict[attr_name] = True + break + # Try again after removing the last word in the name + attr_name = "".join(camel_case_split(attr_name)[:-1]) + + # Let's build that table! + model_names = list(model_name_to_config.keys()) + model_names.sort(key=str.lower) + columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"] + # We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side). + widths = [len(c) + 2 for c in columns] + widths[0] = max([len(name) for name in model_names]) + 2 + + # Build the table per se + table = "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n" + # Use ":-----:" format to center-aligned table cell texts + table += "|" + "|".join([":" + "-" * (w - 2) + ":" for w in widths]) + "|\n" + + check = {True: "✅", False: "❌"} + for name in model_names: + prefix = model_name_to_prefix[name] + line = [ + name, + check[slow_tokenizers[prefix]], + check[fast_tokenizers[prefix]], + check[pt_models[prefix]], + check[tf_models[prefix]], + check[flax_models[prefix]], + ] + table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n" + return table + + +def check_model_table(overwrite=False): + """Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`.""" + current_table, start_index, end_index, lines = _find_text_in_file( + filename=os.path.join(PATH_TO_DOCS, "index.md"), + start_prompt="", + ) + new_table = get_model_table_from_auto_modules() + + if current_table != new_table: + if overwrite: + with open(os.path.join(PATH_TO_DOCS, "index.md"), "w", encoding="utf-8", newline="\n") as f: + f.writelines(lines[:start_index] + [new_table] + lines[end_index:]) + else: + raise ValueError( + "The model table in the `index.md` has not been updated. Run `make fix-copies` to fix this." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") + args = parser.parse_args() + + check_model_table(args.fix_and_overwrite) diff --git a/diffusers/utils/custom_init_isort.py b/diffusers/utils/custom_init_isort.py new file mode 100755 index 0000000..6c2bb7f --- /dev/null +++ b/diffusers/utils/custom_init_isort.py @@ -0,0 +1,330 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utility that sorts the imports in the custom inits of Diffusers. Diffusers uses init files that delay the +import of an object to when it's actually needed. This is to avoid the main init importing all models, which would +make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with +delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the +objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. `isort` or `ruff` +properly sort the second half which looks like traditionl imports, the goal of this script is to sort the first half. + +Use from the root of the repo with: + +```bash +python utils/custom_init_isort.py +``` + +which will auto-sort the imports (used in `make style`). + +For a check only (as used in `make quality`) run: + +```bash +python utils/custom_init_isort.py --check_only +``` +""" + +import argparse +import os +import re +from typing import Any, Callable, List, Optional + + +# Path is defined with the intent you should run this script from the root of the repo. +PATH_TO_TRANSFORMERS = "src/diffusers" + +# Pattern that looks at the indentation in a line. +_re_indent = re.compile(r"^(\s*)\S") +# Pattern that matches `"key":" and puts `key` in group 0. +_re_direct_key = re.compile(r'^\s*"([^"]+)":') +# Pattern that matches `_import_structure["key"]` and puts `key` in group 0. +_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]') +# Pattern that matches `"key",` and puts `key` in group 0. +_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$') +# Pattern that matches any `[stuff]` and puts `stuff` in group 0. +_re_bracket_content = re.compile(r"\[([^\]]+)\]") + + +def get_indent(line: str) -> str: + """Returns the indent in given line (as string).""" + search = _re_indent.search(line) + return "" if search is None else search.groups()[0] + + +def split_code_in_indented_blocks( + code: str, indent_level: str = "", start_prompt: Optional[str] = None, end_prompt: Optional[str] = None +) -> List[str]: + """ + Split some code into its indented blocks, starting at a given level. + + Args: + code (`str`): The code to split. + indent_level (`str`): The indent level (as string) to use for identifying the blocks to split. + start_prompt (`str`, *optional*): If provided, only starts splitting at the line where this text is. + end_prompt (`str`, *optional*): If provided, stops splitting at a line where this text is. + + Warning: + The text before `start_prompt` or after `end_prompt` (if provided) is not ignored, just not split. The input `code` + can thus be retrieved by joining the result. + + Returns: + `List[str]`: The list of blocks. + """ + # Let's split the code into lines and move to start_index. + index = 0 + lines = code.split("\n") + if start_prompt is not None: + while not lines[index].startswith(start_prompt): + index += 1 + blocks = ["\n".join(lines[:index])] + else: + blocks = [] + + # This variable contains the block treated at a given time. + current_block = [lines[index]] + index += 1 + # We split into blocks until we get to the `end_prompt` (or the end of the file). + while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)): + # We have a non-empty line with the proper indent -> start of a new block + if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level: + # Store the current block in the result and rest. There are two cases: the line is part of the block (like + # a closing parenthesis) or not. + if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "): + # Line is part of the current block + current_block.append(lines[index]) + blocks.append("\n".join(current_block)) + if index < len(lines) - 1: + current_block = [lines[index + 1]] + index += 1 + else: + current_block = [] + else: + # Line is not part of the current block + blocks.append("\n".join(current_block)) + current_block = [lines[index]] + else: + # Just add the line to the current block + current_block.append(lines[index]) + index += 1 + + # Adds current block if it's nonempty. + if len(current_block) > 0: + blocks.append("\n".join(current_block)) + + # Add final block after end_prompt if provided. + if end_prompt is not None and index < len(lines): + blocks.append("\n".join(lines[index:])) + + return blocks + + +def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any], str]: + """ + Wraps a key function (as used in a sort) to lowercase and ignore underscores. + """ + + def _inner(x): + return key(x).lower().replace("_", "") + + return _inner + + +def sort_objects(objects: List[Any], key: Optional[Callable[[Any], str]] = None) -> List[Any]: + """ + Sort a list of objects following the rules of isort (all uppercased first, camel-cased second and lower-cased + last). + + Args: + objects (`List[Any]`): + The list of objects to sort. + key (`Callable[[Any], str]`, *optional*): + A function taking an object as input and returning a string, used to sort them by alphabetical order. + If not provided, will default to noop (so a `key` must be provided if the `objects` are not of type string). + + Returns: + `List[Any]`: The sorted list with the same elements as in the inputs + """ + + # If no key is provided, we use a noop. + def noop(x): + return x + + if key is None: + key = noop + # Constants are all uppercase, they go first. + constants = [obj for obj in objects if key(obj).isupper()] + # Classes are not all uppercase but start with a capital, they go second. + classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()] + # Functions begin with a lowercase, they go last. + functions = [obj for obj in objects if not key(obj)[0].isupper()] + + # Then we sort each group. + key1 = ignore_underscore_and_lowercase(key) + return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1) + + +def sort_objects_in_import(import_statement: str) -> str: + """ + Sorts the imports in a single import statement. + + Args: + import_statement (`str`): The import statement in which to sort the imports. + + Returns: + `str`: The same as the input, but with objects properly sorted. + """ + + # This inner function sort imports between [ ]. + def _replace(match): + imports = match.groups()[0] + # If there is one import only, nothing to do. + if "," not in imports: + return f"[{imports}]" + keys = [part.strip().replace('"', "") for part in imports.split(",")] + # We will have a final empty element if the line finished with a comma. + if len(keys[-1]) == 0: + keys = keys[:-1] + return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]" + + lines = import_statement.split("\n") + if len(lines) > 3: + # Here we have to sort internal imports that are on several lines (one per name): + # key: [ + # "object1", + # "object2", + # ... + # ] + + # We may have to ignore one or two lines on each side. + idx = 2 if lines[1].strip() == "[" else 1 + keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])] + sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1]) + sorted_lines = [lines[x[0] + idx] for x in sorted_indices] + return "\n".join(lines[:idx] + sorted_lines + lines[-idx:]) + elif len(lines) == 3: + # Here we have to sort internal imports that are on one separate line: + # key: [ + # "object1", "object2", ... + # ] + if _re_bracket_content.search(lines[1]) is not None: + lines[1] = _re_bracket_content.sub(_replace, lines[1]) + else: + keys = [part.strip().replace('"', "") for part in lines[1].split(",")] + # We will have a final empty element if the line finished with a comma. + if len(keys[-1]) == 0: + keys = keys[:-1] + lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + return "\n".join(lines) + else: + # Finally we have to deal with imports fitting on one line + import_statement = _re_bracket_content.sub(_replace, import_statement) + return import_statement + + +def sort_imports(file: str, check_only: bool = True): + """ + Sort the imports defined in the `_import_structure` of a given init. + + Args: + file (`str`): The path to the init to check/fix. + check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init. + """ + with open(file, encoding="utf-8") as f: + code = f.read() + + # If the file is not a custom init, there is nothing to do. + if "_import_structure" not in code: + return + + # Blocks of indent level 0 + main_blocks = split_code_in_indented_blocks( + code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:" + ) + + # We ignore block 0 (everything untils start_prompt) and the last block (everything after end_prompt). + for block_idx in range(1, len(main_blocks) - 1): + # Check if the block contains some `_import_structure`s thingy to sort. + block = main_blocks[block_idx] + block_lines = block.split("\n") + + # Get to the start of the imports. + line_idx = 0 + while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]: + # Skip dummy import blocks + if "import dummy" in block_lines[line_idx]: + line_idx = len(block_lines) + else: + line_idx += 1 + if line_idx >= len(block_lines): + continue + + # Ignore beginning and last line: they don't contain anything. + internal_block_code = "\n".join(block_lines[line_idx:-1]) + indent = get_indent(block_lines[1]) + # Slit the internal block into blocks of indent level 1. + internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent) + # We have two categories of import key: list or _import_structure[key].append/extend + pattern = _re_direct_key if "_import_structure = {" in block_lines[0] else _re_indirect_key + # Grab the keys, but there is a trap: some lines are empty or just comments. + keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks] + # We only sort the lines with a key. + keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None] + sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])] + + # We reorder the blocks by leaving empty lines/comments as they were and reorder the rest. + count = 0 + reordered_blocks = [] + for i in range(len(internal_blocks)): + if keys[i] is None: + reordered_blocks.append(internal_blocks[i]) + else: + block = sort_objects_in_import(internal_blocks[sorted_indices[count]]) + reordered_blocks.append(block) + count += 1 + + # And we put our main block back together with its first and last line. + main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reordered_blocks + [block_lines[-1]]) + + if code != "\n".join(main_blocks): + if check_only: + return True + else: + print(f"Overwriting {file}.") + with open(file, "w", encoding="utf-8") as f: + f.write("\n".join(main_blocks)) + + +def sort_imports_in_all_inits(check_only=True): + """ + Sort the imports defined in the `_import_structure` of all inits in the repo. + + Args: + check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init. + """ + failures = [] + for root, _, files in os.walk(PATH_TO_TRANSFORMERS): + if "__init__.py" in files: + result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only) + if result: + failures = [os.path.join(root, "__init__.py")] + if len(failures) > 0: + raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.") + args = parser.parse_args() + + sort_imports_in_all_inits(check_only=args.check_only) diff --git a/diffusers/utils/fetch_latest_release_branch.py b/diffusers/utils/fetch_latest_release_branch.py new file mode 100755 index 0000000..9bf578a --- /dev/null +++ b/diffusers/utils/fetch_latest_release_branch.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import requests +from packaging.version import parse + + +# GitHub repository details +USER = "huggingface" +REPO = "diffusers" + + +def fetch_all_branches(user, repo): + branches = [] # List to store all branches + page = 1 # Start from first page + while True: + # Make a request to the GitHub API for the branches + response = requests.get(f"https://api.github.com/repos/{user}/{repo}/branches", params={"page": page}) + + # Check if the request was successful + if response.status_code == 200: + # Add the branches from the current page to the list + branches.extend([branch["name"] for branch in response.json()]) + + # Check if there is a 'next' link for pagination + if "next" in response.links: + page += 1 # Move to the next page + else: + break # Exit loop if there is no next page + else: + print("Failed to retrieve branches:", response.status_code) + break + + return branches + + +def main(): + # Fetch all branches + branches = fetch_all_branches(USER, REPO) + + # Filter branches. + # print(f"Total branches: {len(branches)}") + filtered_branches = [] + for branch in branches: + if branch.startswith("v") and ("-release" in branch or "-patch" in branch): + filtered_branches.append(branch) + # print(f"Filtered: {branch}") + + sorted_branches = sorted(filtered_branches, key=lambda x: parse(x.split("-")[0][1:]), reverse=True) + latest_branch = sorted_branches[0] + # print(f"Latest branch: {latest_branch}") + return latest_branch + + +if __name__ == "__main__": + print(main()) diff --git a/diffusers/utils/fetch_torch_cuda_pipeline_test_matrix.py b/diffusers/utils/fetch_torch_cuda_pipeline_test_matrix.py new file mode 100755 index 0000000..e6a9c4b --- /dev/null +++ b/diffusers/utils/fetch_torch_cuda_pipeline_test_matrix.py @@ -0,0 +1,101 @@ +import json +import logging +import os +from collections import defaultdict +from pathlib import Path + +from huggingface_hub import HfApi + +import diffusers + + +PATH_TO_REPO = Path(__file__).parent.parent.resolve() +ALWAYS_TEST_PIPELINE_MODULES = [ + "controlnet", + "stable_diffusion", + "stable_diffusion_2", + "stable_diffusion_xl", + "stable_diffusion_adapter", + "deepfloyd_if", + "ip_adapters", + "kandinsky", + "kandinsky2_2", + "text_to_video_synthesis", + "wuerstchen", +] +PIPELINE_USAGE_CUTOFF = int(os.getenv("PIPELINE_USAGE_CUTOFF", 50000)) + +logger = logging.getLogger(__name__) +api = HfApi() + + +def filter_pipelines(usage_dict, usage_cutoff=10000): + output = [] + for diffusers_object, usage in usage_dict.items(): + if usage < usage_cutoff: + continue + + is_diffusers_pipeline = hasattr(diffusers.pipelines, diffusers_object) + if not is_diffusers_pipeline: + continue + + output.append(diffusers_object) + + return output + + +def fetch_pipeline_objects(): + models = api.list_models(library="diffusers") + downloads = defaultdict(int) + + for model in models: + is_counted = False + for tag in model.tags: + if tag.startswith("diffusers:"): + is_counted = True + downloads[tag[len("diffusers:") :]] += model.downloads + + if not is_counted: + downloads["other"] += model.downloads + + # Remove 0 downloads + downloads = {k: v for k, v in downloads.items() if v > 0} + pipeline_objects = filter_pipelines(downloads, PIPELINE_USAGE_CUTOFF) + + return pipeline_objects + + +def fetch_pipeline_modules_to_test(): + try: + pipeline_objects = fetch_pipeline_objects() + except Exception as e: + logger.error(e) + raise RuntimeError("Unable to fetch model list from HuggingFace Hub.") + + test_modules = [] + for pipeline_name in pipeline_objects: + module = getattr(diffusers, pipeline_name) + + test_module = module.__module__.split(".")[-2].strip() + test_modules.append(test_module) + + return test_modules + + +def main(): + test_modules = fetch_pipeline_modules_to_test() + test_modules.extend(ALWAYS_TEST_PIPELINE_MODULES) + + # Get unique modules + test_modules = sorted(set(test_modules)) + print(json.dumps(test_modules)) + + save_path = f"{PATH_TO_REPO}/reports" + os.makedirs(save_path, exist_ok=True) + + with open(f"{save_path}/test-pipelines.json", "w") as f: + json.dump({"pipeline_test_modules": test_modules}, f) + + +if __name__ == "__main__": + main() diff --git a/diffusers/utils/get_modified_files.py b/diffusers/utils/get_modified_files.py new file mode 100755 index 0000000..a252bc6 --- /dev/null +++ b/diffusers/utils/get_modified_files.py @@ -0,0 +1,34 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# this script reports modified .py files under the desired list of top-level sub-dirs passed as a list of arguments, e.g.: +# python ./utils/get_modified_files.py utils src tests examples +# +# it uses git to find the forking point and which files were modified - i.e. files not under git won't be considered +# since the output of this script is fed into Makefile commands it doesn't print a newline after the results + +import re +import subprocess +import sys + + +fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8") +modified_files = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8").split() + +joined_dirs = "|".join(sys.argv[1:]) +regex = re.compile(rf"^({joined_dirs}).*?\.py$") + +relevant_modified_files = [x for x in modified_files if regex.match(x)] +print(" ".join(relevant_modified_files), end="") diff --git a/diffusers/utils/log_reports.py b/diffusers/utils/log_reports.py new file mode 100755 index 0000000..dd1b258 --- /dev/null +++ b/diffusers/utils/log_reports.py @@ -0,0 +1,139 @@ +import argparse +import json +import os +from datetime import date +from pathlib import Path + +from slack_sdk import WebClient +from tabulate import tabulate + + +MAX_LEN_MESSAGE = 2900 # slack endpoint has a limit of 3001 characters + +parser = argparse.ArgumentParser() +parser.add_argument("--slack_channel_name", default="diffusers-ci-nightly") + + +def main(slack_channel_name=None): + failed = [] + passed = [] + + group_info = [] + + total_num_failed = 0 + empty_file = False or len(list(Path().glob("*.log"))) == 0 + + total_empty_files = [] + + for log in Path().glob("*.log"): + section_num_failed = 0 + i = 0 + with open(log) as f: + for line in f: + line = json.loads(line) + i += 1 + if line.get("nodeid", "") != "": + test = line["nodeid"] + if line.get("duration", None) is not None: + duration = f'{line["duration"]:.4f}' + if line.get("outcome", "") == "failed": + section_num_failed += 1 + failed.append([test, duration, log.name.split("_")[0]]) + total_num_failed += 1 + else: + passed.append([test, duration, log.name.split("_")[0]]) + empty_file = i == 0 + group_info.append([str(log), section_num_failed, failed]) + total_empty_files.append(empty_file) + os.remove(log) + failed = [] + text = ( + "🌞 There were no failures!" + if not any(total_empty_files) + else "Something went wrong there is at least one empty file - please check GH action results." + ) + no_error_payload = { + "type": "section", + "text": { + "type": "plain_text", + "text": text, + "emoji": True, + }, + } + + message = "" + payload = [ + { + "type": "header", + "text": { + "type": "plain_text", + "text": "🤗 Results of the Diffusers scheduled nightly tests.", + }, + }, + ] + if total_num_failed > 0: + for i, (name, num_failed, failed_tests) in enumerate(group_info): + if num_failed > 0: + if num_failed == 1: + message += f"*{name}: {num_failed} failed test*\n" + else: + message += f"*{name}: {num_failed} failed tests*\n" + failed_table = [] + for test in failed_tests: + failed_table.append(test[0].split("::")) + failed_table = tabulate( + failed_table, + headers=["Test Location", "Test Case", "Test Name"], + showindex="always", + tablefmt="grid", + maxcolwidths=[12, 12, 12], + ) + message += "\n```\n" + failed_table + "\n```" + + if total_empty_files[i]: + message += f"\n*{name}: Warning! Empty file - please check the GitHub action job *\n" + print(f"### {message}") + else: + payload.append(no_error_payload) + + if len(message) > MAX_LEN_MESSAGE: + print(f"Truncating long message from {len(message)} to {MAX_LEN_MESSAGE}") + message = message[:MAX_LEN_MESSAGE] + "..." + + if len(message) != 0: + md_report = { + "type": "section", + "text": {"type": "mrkdwn", "text": message}, + } + payload.append(md_report) + action_button = { + "type": "section", + "text": {"type": "mrkdwn", "text": "*For more details:*"}, + "accessory": { + "type": "button", + "text": {"type": "plain_text", "text": "Check Action results", "emoji": True}, + "url": f"https://github.com/huggingface/diffusers/actions/runs/{os.environ['GITHUB_RUN_ID']}", + }, + } + payload.append(action_button) + + date_report = { + "type": "context", + "elements": [ + { + "type": "plain_text", + "text": f"Nightly test results for {date.today()}", + }, + ], + } + payload.append(date_report) + + print(payload) + + client = WebClient(token=os.environ.get("SLACK_API_TOKEN")) + client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload) + + +if __name__ == "__main__": + args = parser.parse_args() + main(args.slack_channel_name) diff --git a/diffusers/utils/notify_benchmarking_status.py b/diffusers/utils/notify_benchmarking_status.py new file mode 100755 index 0000000..c9c6ab4 --- /dev/null +++ b/diffusers/utils/notify_benchmarking_status.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import requests + + +# Configuration +GITHUB_REPO = "huggingface/diffusers" +GITHUB_RUN_ID = os.getenv("GITHUB_RUN_ID") +SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL") + + +def main(args): + action_url = f"https://github.com/{GITHUB_REPO}/actions/runs/{GITHUB_RUN_ID}" + if args.status == "success": + hub_path = "https://huggingface.co/datasets/diffusers/benchmarks/blob/main/collated_results.csv" + message = ( + "✅ New benchmark workflow successfully run.\n" + f"🕸️ GitHub Action URL: {action_url}.\n" + f"🤗 Check out the benchmarks here: {hub_path}." + ) + else: + message = ( + "❌ Something wrong happened in the benchmarking workflow.\n" + f"Check out the GitHub Action to know more: {action_url}." + ) + + payload = {"text": message} + response = requests.post(SLACK_WEBHOOK_URL, json=payload) + + if response.status_code == 200: + print("Notification sent to Slack successfully.") + else: + print("Failed to send notification to Slack.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--status", type=str, default="success", choices=["success", "failure"]) + args = parser.parse_args() + main(args) diff --git a/diffusers/utils/notify_community_pipelines_mirror.py b/diffusers/utils/notify_community_pipelines_mirror.py new file mode 100755 index 0000000..a7d3a31 --- /dev/null +++ b/diffusers/utils/notify_community_pipelines_mirror.py @@ -0,0 +1,54 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import requests + + +# Configuration +GITHUB_REPO = "huggingface/diffusers" +GITHUB_RUN_ID = os.getenv("GITHUB_RUN_ID") +SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL") +PATH_IN_REPO = os.getenv("PATH_IN_REPO") + + +def main(args): + action_url = f"https://github.com/{GITHUB_REPO}/actions/runs/{GITHUB_RUN_ID}" + if args.status == "success": + hub_path = f"https://huggingface.co/datasets/diffusers/community-pipelines-mirror/tree/main/{PATH_IN_REPO}" + message = ( + "✅ Community pipelines successfully mirrored.\n" + f"🕸️ GitHub Action URL: {action_url}.\n" + f"🤗 Hub location: {hub_path}." + ) + else: + message = f"❌ Something wrong happened. Check out the GitHub Action to know more: {action_url}." + + payload = {"text": message} + response = requests.post(SLACK_WEBHOOK_URL, json=payload) + + if response.status_code == 200: + print("Notification sent to Slack successfully.") + else: + print("Failed to send notification to Slack.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--status", type=str, default="success", choices=["success", "failure"]) + args = parser.parse_args() + main(args) diff --git a/diffusers/utils/notify_slack_about_release.py b/diffusers/utils/notify_slack_about_release.py new file mode 100755 index 0000000..a67dd8b --- /dev/null +++ b/diffusers/utils/notify_slack_about_release.py @@ -0,0 +1,81 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import requests + + +# Configuration +LIBRARY_NAME = "diffusers" +GITHUB_REPO = "huggingface/diffusers" +SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL") + + +def check_pypi_for_latest_release(library_name): + """Check PyPI for the latest release of the library.""" + response = requests.get(f"https://pypi.org/pypi/{library_name}/json") + if response.status_code == 200: + data = response.json() + return data["info"]["version"] + else: + print("Failed to fetch library details from PyPI.") + return None + + +def get_github_release_info(github_repo): + """Fetch the latest release info from GitHub.""" + url = f"https://api.github.com/repos/{github_repo}/releases/latest" + response = requests.get(url) + + if response.status_code == 200: + data = response.json() + return {"tag_name": data["tag_name"], "url": data["html_url"], "release_time": data["published_at"]} + + else: + print("Failed to fetch release info from GitHub.") + return None + + +def notify_slack(webhook_url, library_name, version, release_info): + """Send a notification to a Slack channel.""" + message = ( + f"🚀 New release for {library_name} available: version **{version}** 🎉\n" + f"📜 Release Notes: {release_info['url']}\n" + f"⏱️ Release time: {release_info['release_time']}" + ) + payload = {"text": message} + response = requests.post(webhook_url, json=payload) + + if response.status_code == 200: + print("Notification sent to Slack successfully.") + else: + print("Failed to send notification to Slack.") + + +def main(): + latest_version = check_pypi_for_latest_release(LIBRARY_NAME) + release_info = get_github_release_info(GITHUB_REPO) + parsed_version = release_info["tag_name"].replace("v", "") + + if latest_version and release_info and latest_version == parsed_version: + notify_slack(SLACK_WEBHOOK_URL, LIBRARY_NAME, latest_version, release_info) + else: + print(f"{latest_version=}, {release_info=}, {parsed_version=}") + raise ValueError("There were some problems.") + + +if __name__ == "__main__": + main() diff --git a/diffusers/utils/overwrite_expected_slice.py b/diffusers/utils/overwrite_expected_slice.py new file mode 100755 index 0000000..07778a0 --- /dev/null +++ b/diffusers/utils/overwrite_expected_slice.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from collections import defaultdict + + +def overwrite_file(file, class_name, test_name, correct_line, done_test): + _id = f"{file}_{class_name}_{test_name}" + done_test[_id] += 1 + + with open(file, "r") as f: + lines = f.readlines() + + class_regex = f"class {class_name}(" + test_regex = f"{4 * ' '}def {test_name}(" + line_begin_regex = f"{8 * ' '}{correct_line.split()[0]}" + another_line_begin_regex = f"{16 * ' '}{correct_line.split()[0]}" + in_class = False + in_func = False + in_line = False + insert_line = False + count = 0 + spaces = 0 + + new_lines = [] + for line in lines: + if line.startswith(class_regex): + in_class = True + elif in_class and line.startswith(test_regex): + in_func = True + elif in_class and in_func and (line.startswith(line_begin_regex) or line.startswith(another_line_begin_regex)): + spaces = len(line.split(correct_line.split()[0])[0]) + count += 1 + + if count == done_test[_id]: + in_line = True + + if in_class and in_func and in_line: + if ")" not in line: + continue + else: + insert_line = True + + if in_class and in_func and in_line and insert_line: + new_lines.append(f"{spaces * ' '}{correct_line}") + in_class = in_func = in_line = insert_line = False + else: + new_lines.append(line) + + with open(file, "w") as f: + for line in new_lines: + f.write(line) + + +def main(correct, fail=None): + if fail is not None: + with open(fail, "r") as f: + test_failures = {l.strip() for l in f.readlines()} + else: + test_failures = None + + with open(correct, "r") as f: + correct_lines = f.readlines() + + done_tests = defaultdict(int) + for line in correct_lines: + file, class_name, test_name, correct_line = line.split("::") + if test_failures is None or "::".join([file, class_name, test_name]) in test_failures: + overwrite_file(file, class_name, test_name, correct_line, done_tests) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--correct_filename", help="filename of tests with expected result") + parser.add_argument("--fail_filename", help="filename of test failures", type=str, default=None) + args = parser.parse_args() + + main(args.correct_filename, args.fail_filename) diff --git a/diffusers/utils/print_env.py b/diffusers/utils/print_env.py new file mode 100755 index 0000000..3e4495c --- /dev/null +++ b/diffusers/utils/print_env.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# this script dumps information about the environment + +import os +import platform +import sys + + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +print("Python version:", sys.version) + +print("OS platform:", platform.platform()) +print("OS architecture:", platform.machine()) + +try: + import torch + + print("Torch version:", torch.__version__) + print("Cuda available:", torch.cuda.is_available()) + print("Cuda version:", torch.version.cuda) + print("CuDNN version:", torch.backends.cudnn.version()) + print("Number of GPUs available:", torch.cuda.device_count()) +except ImportError: + print("Torch version:", None) + +try: + import transformers + + print("transformers version:", transformers.__version__) +except ImportError: + print("transformers version:", None) diff --git a/diffusers/utils/release.py b/diffusers/utils/release.py new file mode 100755 index 0000000..a0800b9 --- /dev/null +++ b/diffusers/utils/release.py @@ -0,0 +1,162 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import re + +import packaging.version + + +PATH_TO_EXAMPLES = "examples/" +REPLACE_PATTERNS = { + "examples": (re.compile(r'^check_min_version\("[^"]+"\)\s*$', re.MULTILINE), 'check_min_version("VERSION")\n'), + "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'), + "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'), + "doc": (re.compile(r'^(\s*)release\s*=\s*"[^"]+"$', re.MULTILINE), 'release = "VERSION"\n'), +} +REPLACE_FILES = { + "init": "src/diffusers/__init__.py", + "setup": "setup.py", +} +README_FILE = "README.md" + + +def update_version_in_file(fname, version, pattern): + """Update the version in one file using a specific pattern.""" + with open(fname, "r", encoding="utf-8", newline="\n") as f: + code = f.read() + re_pattern, replace = REPLACE_PATTERNS[pattern] + replace = replace.replace("VERSION", version) + code = re_pattern.sub(replace, code) + with open(fname, "w", encoding="utf-8", newline="\n") as f: + f.write(code) + + +def update_version_in_examples(version): + """Update the version in all examples files.""" + for folder, directories, fnames in os.walk(PATH_TO_EXAMPLES): + # Removing some of the folders with non-actively maintained examples from the walk + if "research_projects" in directories: + directories.remove("research_projects") + if "legacy" in directories: + directories.remove("legacy") + for fname in fnames: + if fname.endswith(".py"): + update_version_in_file(os.path.join(folder, fname), version, pattern="examples") + + +def global_version_update(version, patch=False): + """Update the version in all needed files.""" + for pattern, fname in REPLACE_FILES.items(): + update_version_in_file(fname, version, pattern) + if not patch: + update_version_in_examples(version) + + +def clean_main_ref_in_model_list(): + """Replace the links from main doc tp stable doc in the model list of the README.""" + # If the introduction or the conclusion of the list change, the prompts may need to be updated. + _start_prompt = "🤗 Transformers currently provides the following architectures" + _end_prompt = "1. Want to contribute a new model?" + with open(README_FILE, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + + # Find the start of the list. + start_index = 0 + while not lines[start_index].startswith(_start_prompt): + start_index += 1 + start_index += 1 + + index = start_index + # Update the lines in the model list. + while not lines[index].startswith(_end_prompt): + if lines[index].startswith("1."): + lines[index] = lines[index].replace( + "https://huggingface.co/docs/diffusers/main/model_doc", + "https://huggingface.co/docs/diffusers/model_doc", + ) + index += 1 + + with open(README_FILE, "w", encoding="utf-8", newline="\n") as f: + f.writelines(lines) + + +def get_version(): + """Reads the current version in the __init__.""" + with open(REPLACE_FILES["init"], "r") as f: + code = f.read() + default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0] + return packaging.version.parse(default_version) + + +def pre_release_work(patch=False): + """Do all the necessary pre-release steps.""" + # First let's get the default version: base version if we are in dev, bump minor otherwise. + default_version = get_version() + if patch and default_version.is_devrelease: + raise ValueError("Can't create a patch version from the dev branch, checkout a released version!") + if default_version.is_devrelease: + default_version = default_version.base_version + elif patch: + default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}" + else: + default_version = f"{default_version.major}.{default_version.minor + 1}.0" + + # Now let's ask nicely if that's the right one. + version = input(f"Which version are you releasing? [{default_version}]") + if len(version) == 0: + version = default_version + + print(f"Updating version to {version}.") + global_version_update(version, patch=patch) + + +# if not patch: +# print("Cleaning main README, don't forget to run `make fix-copies`.") +# clean_main_ref_in_model_list() + + +def post_release_work(): + """Do all the necessary post-release steps.""" + # First let's get the current version + current_version = get_version() + dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0" + current_version = current_version.base_version + + # Check with the user we got that right. + version = input(f"Which version are we developing now? [{dev_version}]") + if len(version) == 0: + version = dev_version + + print(f"Updating version to {version}.") + global_version_update(version) + + +# print("Cleaning main README, don't forget to run `make fix-copies`.") +# clean_main_ref_in_model_list() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.") + parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.") + args = parser.parse_args() + if not args.post_release: + pre_release_work(patch=args.patch) + elif args.patch: + print("Nothing to do after a patch :-)") + else: + post_release_work() diff --git a/diffusers/utils/stale.py b/diffusers/utils/stale.py new file mode 100755 index 0000000..c01b6d5 --- /dev/null +++ b/diffusers/utils/stale.py @@ -0,0 +1,68 @@ +# Copyright 2024 The HuggingFace Team, the AllenNLP library authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Script to close stale issue. Taken in part from the AllenNLP repository. +https://github.com/allenai/allennlp. +""" + +import os +from datetime import datetime as dt +from datetime import timezone + +from github import Github + + +LABELS_TO_EXEMPT = [ + "good first issue", + "good second issue", + "good difficult issue", + "enhancement", + "new pipeline/model", + "new scheduler", + "wip", +] + + +def main(): + g = Github(os.environ["GITHUB_TOKEN"]) + repo = g.get_repo("huggingface/diffusers") + open_issues = repo.get_issues(state="open") + + for issue in open_issues: + labels = [label.name.lower() for label in issue.get_labels()] + if "stale" in labels: + comments = sorted(issue.get_comments(), key=lambda i: i.created_at, reverse=True) + last_comment = comments[0] if len(comments) > 0 else None + if last_comment is not None and last_comment.user.login != "github-actions[bot]": + # Opens the issue if someone other than Stalebot commented. + issue.edit(state="open") + issue.remove_from_labels("stale") + elif ( + (dt.now(timezone.utc) - issue.updated_at).days > 23 + and (dt.now(timezone.utc) - issue.created_at).days >= 30 + and not any(label in LABELS_TO_EXEMPT for label in labels) + ): + # Post a Stalebot notification after 23 days of inactivity. + issue.create_comment( + "This issue has been automatically marked as stale because it has not had " + "recent activity. If you think this still needs to be addressed " + "please comment on this thread.\n\nPlease note that issues that do not follow the " + "[contributing guidelines](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md) " + "are likely to be ignored." + ) + issue.add_to_labels("stale") + + +if __name__ == "__main__": + main() diff --git a/diffusers/utils/update_metadata.py b/diffusers/utils/update_metadata.py new file mode 100755 index 0000000..103a2b9 --- /dev/null +++ b/diffusers/utils/update_metadata.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utility that updates the metadata of the Diffusers library in the repository `huggingface/diffusers-metadata`. + +Usage for an update (as used by the GitHub action `update_metadata`): + +```bash +python utils/update_metadata.py +``` + +Script modified from: +https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py +""" + +import argparse +import os +import tempfile + +import pandas as pd +from datasets import Dataset +from huggingface_hub import hf_hub_download, upload_folder + +from diffusers.pipelines.auto_pipeline import ( + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, + AUTO_INPAINT_PIPELINES_MAPPING, + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, +) + + +PIPELINE_TAG_JSON = "pipeline_tags.json" + + +def get_supported_pipeline_table() -> dict: + """ + Generates a dictionary containing the supported auto classes for each pipeline type, + using the content of the auto modules. + """ + # All supported pipelines for automatic mapping. + all_supported_pipeline_classes = [ + (class_name.__name__, "text-to-image", "AutoPipelineForText2Image") + for _, class_name in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items() + ] + all_supported_pipeline_classes += [ + (class_name.__name__, "image-to-image", "AutoPipelineForImage2Image") + for _, class_name in AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items() + ] + all_supported_pipeline_classes += [ + (class_name.__name__, "image-to-image", "AutoPipelineForInpainting") + for _, class_name in AUTO_INPAINT_PIPELINES_MAPPING.items() + ] + all_supported_pipeline_classes = list(set(all_supported_pipeline_classes)) + all_supported_pipeline_classes.sort(key=lambda x: x[0]) + + data = {} + data["pipeline_class"] = [sample[0] for sample in all_supported_pipeline_classes] + data["pipeline_tag"] = [sample[1] for sample in all_supported_pipeline_classes] + data["auto_class"] = [sample[2] for sample in all_supported_pipeline_classes] + + return data + + +def update_metadata(commit_sha: str): + """ + Update the metadata for the Diffusers repo in `huggingface/diffusers-metadata`. + + Args: + commit_sha (`str`): The commit SHA on Diffusers corresponding to this update. + """ + pipelines_table = get_supported_pipeline_table() + pipelines_table = pd.DataFrame(pipelines_table) + pipelines_dataset = Dataset.from_pandas(pipelines_table) + + hub_pipeline_tags_json = hf_hub_download( + repo_id="huggingface/diffusers-metadata", + filename=PIPELINE_TAG_JSON, + repo_type="dataset", + ) + with open(hub_pipeline_tags_json) as f: + hub_pipeline_tags_json = f.read() + + with tempfile.TemporaryDirectory() as tmp_dir: + pipelines_dataset.to_json(os.path.join(tmp_dir, PIPELINE_TAG_JSON)) + + with open(os.path.join(tmp_dir, PIPELINE_TAG_JSON)) as f: + pipeline_tags_json = f.read() + + hub_pipeline_tags_equal = hub_pipeline_tags_json == pipeline_tags_json + if hub_pipeline_tags_equal: + print("No updates, not pushing the metadata files.") + return + + if commit_sha is not None: + commit_message = ( + f"Update with commit {commit_sha}\n\nSee: " + f"https://github.com/huggingface/diffusers/commit/{commit_sha}" + ) + else: + commit_message = "Update" + + upload_folder( + repo_id="huggingface/diffusers-metadata", + folder_path=tmp_dir, + repo_type="dataset", + commit_message=commit_message, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--commit_sha", default=None, type=str, help="The sha of the commit going with this update.") + args = parser.parse_args() + + update_metadata(args.commit_sha) diff --git a/infer/run_validation.py b/infer/run_validation.py new file mode 100644 index 0000000..0a58e22 --- /dev/null +++ b/infer/run_validation.py @@ -0,0 +1,744 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Standard library imports +import os +import argparse +import random +import gc +from pathlib import Path +import json +import numpy as np +import cv2 +from PIL import Image +from tqdm.auto import tqdm + +# PyTorch imports +import torch +from torch.utils.data import DataLoader, Dataset +import torchvision.transforms as TT +from torchvision.transforms.functional import resize +from torchvision.transforms import InterpolationMode + +# Hugging Face imports +from transformers import AutoTokenizer, T5EncoderModel + +# Diffusers imports +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXTransformer3DModel, + CogvideoXBranchModel, + CogVideoXI2VDualInpaintPipeline +) +from diffusers.utils import export_to_video + + +def get_args(): + parser = argparse.ArgumentParser(description="Validation script for VideoPainter") + + # Model and checkpoint + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, + help="Path to the checkpoint directory containing the branch model" + ) + parser.add_argument( + "--lora_path", + type=str, + default=None, + help="Path to the lora file" + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models." + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Revision of pretrained model identifier" + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files" + ) + + # Data + parser.add_argument( + "--meta_file_path", + type=str, + required=True, + help="Path to validation meta data file" + ) + parser.add_argument( + "--condition_data_root", + type=str, + required=True, + help="Path to condition video root directory" + ) + parser.add_argument( + "--video_data_root", + type=str, + required=True, + help="Path to ground truth video root directory" + ) + + # Output + parser.add_argument( + "--output_dir", + type=str, + default="validation_outputs", + help="Directory to save validation outputs" + ) + + # Video params + parser.add_argument( + "--height", + type=int, + default=480, + help="Video height" + ) + parser.add_argument( + "--width", + type=int, + default=720, + help="Video width" + ) + parser.add_argument( + "--max_num_frames", + type=int, + default=49, + help="Maximum number of frames to process" + ) + parser.add_argument( + "--fps", + type=int, + default=8, + help="Frames per second for saving videos" + ) + + # Generation params + parser.add_argument( + "--guidance_scale", + type=float, + default=6.0, + help="Guidance scale for classifier-free guidance" + ) + parser.add_argument( + "--use_dynamic_cfg", + action="store_true", + help="Whether to use dynamic cfg" + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=50, + help="Number of inference steps" + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Random seed for reproducibility" + ) + + # Branch model params + parser.add_argument( + "--branch_layer_num", + type=int, + default=4, + help="Number of layers in the branch" + ) + parser.add_argument( + "--add_first", + action="store_true", + help="Enable add_first feature" + ) + parser.add_argument( + "--mask_add", + action="store_true", + help="Enable mask_add feature" + ) + parser.add_argument( + "--mask_background", + action="store_true", + help="Enable mask_background feature" + ) + parser.add_argument( + "--wo_text", + action="store_true", + help="Disable text conditioning" + ) + parser.add_argument( + "--first_frame_gt", + action="store_true", + help="Use first frame from ground truth" + ) + parser.add_argument( + "--conditioning_scale", + type=float, + default=1.0, + help="Conditioning scale for classifier-free guidance" + ) + + # Performance + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Enable memory efficient attention with xformers" + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size for validation" + ) + parser.add_argument( + "--num_workers", + type=int, + default=0, + help="Number of dataloader workers" + ) + parser.add_argument( + "--mixed_precision", + type=str, + choices=["no", "fp16", "bf16"], + default="no", + help="Mixed precision type" + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help="Enable TF32 for faster inference" + ) + parser.add_argument( + "--target_frames", + type=int, + default=25, + help="Number of frames to target and save" + ) + + # Add these new arguments + parser.add_argument( + "--total_sections", + type=int, + default=1, + help="Total number of sections to split the dataset into" + ) + parser.add_argument( + "--section", + type=int, + default=0, + help="Which section to process (0-based index)" + ) + + return parser.parse_args() + + +class VideoInpaintingDataset(Dataset): + def __init__( + self, + meta_file_path, + condition_data_root, + video_data_root, + max_num_frames=49, + selected_indices=None + ): + super().__init__() + self.meta_file_path = Path(meta_file_path) + self.condition_data_root = Path(condition_data_root) + self.video_data_root = Path(video_data_root) + self.max_num_frames = max_num_frames + + if not self.meta_file_path.exists(): + raise ValueError(f"Meta file does not exist: {self.meta_file_path}") + if not self.video_data_root.exists(): + raise ValueError(f"Instance videos root folder does not exist: {self.video_data_root}") + if not self.condition_data_root.exists(): + raise ValueError(f"Condition videos root folder does not exist: {self.condition_data_root}") + + self._load_dataset(selected_indices) + + def _load_dataset(self, selected_indices=None): + try: + metadata = json.load(open(self.meta_file_path, "r")) + except: + metadata = [] + with open(self.meta_file_path, 'r') as f: + for line in f: + metadata.append(json.loads(line)) + + # Use default prompt if needed + self.prompts = ["Moving view of a city scene." for _ in range(len(metadata))] + self.videos = [[self.video_data_root / gt_frame_path for gt_frame_path in sample["gt_frames"]] for sample in metadata] + self.condition_videos = [[self.condition_data_root / render_frame_path for render_frame_path in sample["render_frames"]] for sample in metadata] + self.scene_idxs = [sample["scene_idx"] for sample in metadata] + + # cast to str + self.scene_idxs = [f'{scene_idx:06d}' if isinstance(scene_idx, int) else scene_idx for scene_idx in self.scene_idxs] + + + # Filter by selected indices if provided + if selected_indices is not None: + self.prompts = [p for i, p in enumerate(self.prompts) if i in selected_indices] + self.videos = [v for i, v in enumerate(self.videos) if i in selected_indices] + self.condition_videos = [c for i, c in enumerate(self.condition_videos) if i in selected_indices] + self.scene_idxs = [s for i, s in enumerate(self.scene_idxs) if i in selected_indices] + + self.bad_index = [] + + def __len__(self): + return len(self.videos) + + def __getitem__(self, index): + if index in self.bad_index: + print(f"Bad index: {index}") + return self.__getitem__(random.randint(0, len(self.videos) - 1)) + + prompt = self.prompts[index] + video_frames_paths = self.videos[index] + condition_video_frames_paths = self.condition_videos[index] + scene_idx = self.scene_idxs[index] + + try: + video_frames = [cv2.imread(str(frame_path)) for frame_path in video_frames_paths] + condition_video_frames = [cv2.imread(str(frame_path)) for frame_path in condition_video_frames_paths] + except Exception as e: + print(f"Error loading video frames: {e}") + self.bad_index.append(index) + return self.__getitem__(random.randint(0, len(self.videos) - 1)) + + video = np.array(video_frames) + condition_video = np.array(condition_video_frames) + + return { + "prompt": prompt, + "video": video, + "condition_video": condition_video, + "scene_idx": scene_idx, + } + + +class MyWebDataset: + def __init__( + self, + resolution, + max_num_frames, + max_sequence_length=226, + proportion_empty_prompts=0, + first_frame_gt=False, + video_reshape_mode="resize" + ): + self.resolution = resolution + self.max_num_frames = max_num_frames + self.max_sequence_length = max_sequence_length + self.proportion_empty_prompts = proportion_empty_prompts + self.first_frame_gt = first_frame_gt + self.video_reshape_mode = video_reshape_mode + self.resolutions = (max_num_frames, resolution[0], resolution[1]) + def _find_nearest_resolution(self, height, width): + nearest_res = [self.resolutions[0], self.resolutions[1], self.resolutions[2]] + return nearest_res[1], nearest_res[2] + + def _resize_for_rectangle_crop(self, arr, image_size): + reshape_mode = self.video_reshape_mode + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] + delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + + def __call__(self, examples): + pixel_values = [] + conditioning_pixel_values = [] + input_ids = [] + scene_idxs = [] + + for example in examples: + caption = example["prompt"] + video = example["video"] # frame, height, width, c + condition_video = example["condition_video"] # frame, height, width, c + scene_idx = example["scene_idx"] + frame, height, width, c = video.shape + + # Handle frame count + if frame > self.max_num_frames: + begin_idx = 0 + end_idx = begin_idx + self.max_num_frames + video = video[begin_idx:end_idx] + condition_video = condition_video[begin_idx:end_idx] + frame = end_idx - begin_idx + elif frame <= self.max_num_frames: + remainder = (3 + (frame % 4)) % 4 + if remainder != 0: + video = video[:-remainder] + condition_video = condition_video[:-remainder] + frame = video.shape[0] + + # Pad to max_num_frames frames + num_repeats = self.max_num_frames - frame + last_frame = video[-1:] + repeated_frames = np.repeat(last_frame, num_repeats, axis=0) + video = np.concatenate([video, repeated_frames], axis=0) + + last_condition_frame = condition_video[-1:] + repeated_condition_frames = np.repeat(last_condition_frame, num_repeats, axis=0) + condition_video = np.concatenate([condition_video, repeated_condition_frames], axis=0) + + assert video.shape[0] == self.max_num_frames, f"video shape {video.shape[0]} is not equal to max_num_frames {self.max_num_frames}" + assert condition_video.shape[0] == self.max_num_frames, f"condition_video shape {condition_video.shape[0]} is not equal to max_num_frames {self.max_num_frames}" + + # Resize video and condition_video + video = torch.from_numpy(video).permute(0, 3, 1, 2) + condition_video = torch.from_numpy(condition_video).permute(0, 3, 1, 2) + nearest_res = self._find_nearest_resolution(video.shape[2], video.shape[3]) + + if self.video_reshape_mode == 'center': + video_resized = self._resize_for_rectangle_crop(video, nearest_res) + condition_video_resized = self._resize_for_rectangle_crop(condition_video, nearest_res) + elif self.video_reshape_mode == 'resize': + video_resized = torch.stack([resize(frame, nearest_res) for frame in video], dim=0) + condition_video_resized = torch.stack([resize(frame, nearest_res) for frame in condition_video], dim=0) + else: + raise NotImplementedError + + video = video_resized + condition_video = condition_video_resized + + if self.first_frame_gt: + condition_video[0] = video[0] + + # Convert to PIL images for pipeline input + video_pil = [Image.fromarray(cv2.cvtColor(frame.permute(1, 2, 0).numpy(), cv2.COLOR_BGR2RGB)) + for frame in video] + condition_video_pil = [Image.fromarray(cv2.cvtColor(frame.permute(1, 2, 0).numpy(), cv2.COLOR_BGR2RGB)) + for frame in condition_video] + + pixel_values.append(video_pil) + conditioning_pixel_values.append(condition_video_pil) + input_ids.append(caption) + scene_idxs.append(scene_idx) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "input_ids": input_ids, + "scene_idxs": scene_idxs, + } + + +def concatenate_images_horizontally(images1, images2, images3, output_type="np"): + ''' + Concatenate three lists of images horizontally. + Args: + images1: List[Image.Image] or List[np.ndarray] + images2: List[Image.Image] or List[np.ndarray] + images3: List[Image.Image] or List[np.ndarray] + Returns: + List[Image.Image] or List[np.ndarray] + ''' + concatenated_images = [] + for img1, img2, img3 in zip(images1, images2, images3): + # Convert images to numpy arrays + arr1 = np.array(img1) + arr2 = np.array(img2) + arr3 = np.array(img3) + + # Concatenate arrays horizontally + concatenated_img = np.concatenate((arr1, arr2, arr3), axis=1) + + # Convert back to PIL Image if needed + if output_type == "pil": + concatenated_img = Image.fromarray(concatenated_img) + concatenated_images.append(concatenated_img) + return concatenated_images + + +def main(): + args = get_args() + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Set device and weight type + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + # Set seed for reproducibility + if args.seed is not None: + torch.manual_seed(args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + + # Load models + print("Loading models...") + + # Branch + # First load from pretrained model + if os.path.exists(os.path.join(args.checkpoint_path, "branch")): + branch = CogvideoXBranchModel.from_pretrained(os.path.join(args.checkpoint_path, "branch"), torch_dtype=weight_dtype).to(dtype=weight_dtype).cuda() + print(f"Loading branch model from {args.checkpoint_path}") + if args.lora_path is None: + pipe = CogVideoXI2VDualInpaintPipeline.from_pretrained( + args.pretrained_model_name_or_path, + branch=branch, + torch_dtype=weight_dtype, + ) + else: + print(f"Loading the lora from: {args.lora_path}") + # load the transformer + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=weight_dtype, + id_pool_resample_learnable=False, + ).to(dtype=weight_dtype).cuda() + + pipe = CogVideoXI2VDualInpaintPipeline.from_pretrained( + args.pretrained_model_name_or_path, + branch=branch, + transformer=transformer, + torch_dtype=weight_dtype, + ) + + pipe.load_lora_weights( + args.lora_path, + weight_name="pytorch_lora_weights.safetensors", + adapter_name="test_1", + target_modules=["transformer"] + ) + # pipe.fuse_lora(lora_scale=1 / lora_rank) + + list_adapters_component_wise = pipe.get_list_adapters() + print(f"list_adapters_component_wise: {list_adapters_component_wise}") + else: + raise ValueError(f"Branch model not found in {args.checkpoint_path}") + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=weight_dtype, + ).to(device=device, dtype=weight_dtype).cuda() + branch = CogvideoXBranchModel.from_transformer( + transformer=transformer, + num_layers=1, + attention_head_dim=transformer.config.attention_head_dim, + num_attention_heads=transformer.config.num_attention_heads, + load_weights_from_transformer=True + ).to(dtype=weight_dtype).cuda() + pipe = CogVideoXI2VDualInpaintPipeline.from_pretrained( + model_path, + branch=branch, + transformer=transformer, + torch_dtype=weight_dtype, + # device_map='balanced', + ) + print("Intializing model from pretrained model") + + pipe.text_encoder.requires_grad_(False) + pipe.transformer.requires_grad_(False) + pipe.vae.requires_grad_(False) + pipe.branch.requires_grad_(False) + + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + pipe.to("cuda") + + # if long_video: + + # pipe.enable_model_cpu_offload() + # pipe.enable_sequential_cpu_offload() + # pipe.vae.enable_slicing() + # pipe.vae.enable_tiling() + + # Enable xformers if requested + if args.enable_xformers_memory_efficient_attention: + try: + pipe.enable_xformers_memory_efficient_attention() + except ImportError: + print("xformers is not available. Make sure it is installed correctly") + + # load the json to check the length + with open(args.meta_file_path, 'r') as f: + metadata = json.load(f) + print(f"Total number of samples: {len(metadata)}") + + # scan through the metadata to check if it exists in the output directory + valid_indices = [] + for i, sample in enumerate(metadata): + # check! + if os.path.exists(os.path.join(args.output_dir, 'frames', f"sample_{sample['scene_idx']}")) and len(os.listdir(os.path.join(args.output_dir, 'frames', f"sample_{sample['scene_idx']}"))) > 24: + print(f"Skipping sample {args.output_dir}/frames/sample_{sample['scene_idx']} because it already exists") + continue + valid_indices.append(i) + + # Calculate indices for the selected section + num_samples = len(valid_indices) + part_size = (num_samples + args.total_sections - 1) // args.total_sections # Ceiling division + start_idx = args.section * part_size + end_idx = min((args.section + 1) * part_size, num_samples) # Ensure we don't go beyond the last sample + + selected_indices = valid_indices[start_idx:end_idx] + print(f"Processing section {args.section+1}/{args.total_sections}: Samples {start_idx} to {end_idx-1} ({end_idx - start_idx} samples)") + + # Create dataset for this section + validation_dataset = VideoInpaintingDataset( + meta_file_path=args.meta_file_path, + condition_data_root=args.condition_data_root, + video_data_root=args.video_data_root, + max_num_frames=args.max_num_frames, + selected_indices=selected_indices + ) + + validation_dataloader = DataLoader( + validation_dataset, + batch_size=args.batch_size, + shuffle=False, + collate_fn=MyWebDataset( + resolution=[args.height, args.width], + max_num_frames=args.max_num_frames, + first_frame_gt=args.first_frame_gt, + video_reshape_mode="center", + ), + pin_memory=True, + num_workers=args.num_workers, + ) + + # Set up generator for reproducibility + generator = torch.Generator(device=device).manual_seed(args.seed) if args.seed else None + + # Run inference + print(f"Running inference on {len(validation_dataset)} validation samples...") + for i, batch in enumerate(tqdm(validation_dataloader)): + + # check! + if os.path.exists(os.path.join(args.output_dir, 'frames', f"sample_{batch['scene_idxs'][0]}")) and len(os.listdir(os.path.join(args.output_dir, 'frames', f"sample_{batch['scene_idxs'][0]}"))) > 24: + print(f"Skipping sample {args.output_dir}/frames/sample_{batch['scene_idxs'][0]} because it already exists") + continue + + # Set up pipeline args + pipeline_args = { + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + "image": batch["pixel_values"][0][0], + "video": batch["pixel_values"], + "prompt": batch["input_ids"], + "condition_video": batch["conditioning_pixel_values"], + "num_frames": len(batch["pixel_values"][0]), + "strength": 1.0, + "conditioning_scale": args.conditioning_scale, + "mask_background": args.mask_background, + "add_first": args.add_first, + "wo_text": args.wo_text, + "mask_add": args.mask_add, + "replace_gt": True, # Try both options + } + + # Run inference + output = pipe( + **pipeline_args, + num_inference_steps=args.num_inference_steps, + generator=generator, + output_type="np" + ) + + # Get generated video frames + generated_frames = output.frames[0] + + # Process original video and condition video + original_video = pipe.video_processor.preprocess_video( + pipeline_args['video'][0], + height=generated_frames.shape[1], + width=generated_frames.shape[2] + ) + + condition_video = pipe.video_processor.preprocess_video( + pipeline_args['condition_video'][0], + height=generated_frames.shape[1], + width=generated_frames.shape[2] + ) + + original_video = pipe.video_processor.postprocess_video(video=original_video, output_type="np")[0] + condition_video = pipe.video_processor.postprocess_video(video=condition_video, output_type="np")[0] + + # select target frames + if args.target_frames is not None: + original_video = original_video[:args.target_frames] + condition_video = condition_video[:args.target_frames] + generated_frames = generated_frames[:args.target_frames] + + # Concatenate horizontally + concatenated_frames = concatenate_images_horizontally( + original_video, + condition_video, + generated_frames + ) + + # Save as video + scene_idx = batch["scene_idxs"][0] + video_filename = os.path.join(args.output_dir, 'cat_videos', f"sample_{scene_idx}.mp4") + os.makedirs(os.path.join(args.output_dir, 'cat_videos'), exist_ok=True) + export_to_video(concatenated_frames, video_filename, fps=args.fps) + + # Save individual frames + frames_dir = os.path.join(args.output_dir, 'frames', f"sample_{scene_idx}") + os.makedirs(frames_dir, exist_ok=True) + + for j, frame in enumerate(generated_frames): + frame_filename = os.path.join(frames_dir, f"frame_{j}.png") + # Convert from float32 to uint8 + frame = (frame * 255).astype(np.uint8) + Image.fromarray(frame).save(frame_filename) + + # Clean up to avoid memory issues + del generated_frames, original_video, condition_video, concatenated_frames + torch.cuda.empty_cache() + gc.collect() + + # except Exception as e: + # print(f"Error processing sample {i}: {e}") + # continue + + print(f"Validation complete. Results saved to {args.output_dir}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/infer/run_validation.sh b/infer/run_validation.sh new file mode 100644 index 0000000..21da7cb --- /dev/null +++ b/infer/run_validation.sh @@ -0,0 +1,123 @@ +# fullset !! +# pretrained_model_name_or_path="/lpai/volumes/ad-vla-vol-ga/chenantong/ckpts/CogVideoX-5b-I2V" +# checkpoint_path="../train/nuscenes-inpainting/2nodes/checkpoint-28000" +# # meta_file_path="/lpai/volumes/ad-vla-vol-ga/chenantong/datasets/NuScenes/annotation/nuScenes_val_fullset.json" +# meta_file_path="/lpai/volumes/ad-wm-vol-ga/chenantong/datasets/NuScenes/annotation/nuScenes_val_selected1k.json" +# # condition_data_root="/lpai/dataset/nuscenes-val-render-mons/0-1-0/nuscenes_vistaval_balanced1k_stride1window3_nonedit" +# # condition_data_root="/lpai/volumes/ad-vla-vol-ga/chenantong/datasets/NuScenes/renderings" +# condition_data_root="/lpai/volumes/ad-wm-vol-ga/chenantong/datasets/NuScenes/renderings_shuffletraj" +# video_data_root="/lpai/dataset/nuscenes-imgs/0-1-0/nuscenes" + +# for gpu in 0 1 2 3 4 5 6 7; do +# CUDA_VISIBLE_DEVICES=$gpu python run_validation.py \ +# --checkpoint_path $checkpoint_path \ +# --pretrained_model_name_or_path $pretrained_model_name_or_path \ +# --meta_file_path $meta_file_path \ +# --condition_data_root $condition_data_root \ +# --video_data_root $video_data_root \ +# --output_dir validation_results/2nodes_28ksteps/shuffle1k \ +# --height 480 \ +# --width 720 \ +# --max_num_frames 49 \ +# --mixed_precision bf16 \ +# --target_frames 25 \ +# --total_sections 16 \ +# --section $((gpu + 8)) & +# done + +# Evaluate different checkpoints +#steps=(10000 20000 22000 24000 26000 28000) +#for i in "${!steps[@]}"; do +# checkpoint_step=${steps[$i]} +# checkpoint_path="../train/nuscenes-inpainting/2nodes/checkpoint-${checkpoint_step}" +# meta_file_path="/lpai/volumes/ad-vla-vol-ga/chenantong/datasets/NuScenes/annotation/nuScenes_val_1k_sampled_10.json" +# +# CUDA_VISIBLE_DEVICES=$i python run_validation.py \ +# --checkpoint_path $checkpoint_path \ +# --pretrained_model_name_or_path $pretrained_model_name_or_path \ +# --meta_file_path $meta_file_path \ +# --condition_data_root $condition_data_root \ +# --video_data_root $video_data_root \ +# --output_dir validation_results/subset10/2nodes_${checkpoint_step}steps \ +# --height 480 \ +# --width 720 \ +# --max_num_frames 49 \ +# --mixed_precision bf16 \ +# --target_frames 25 \ +# --total_sections 1 \ +# --section 0 & +#done + +# pretrained_model_name_or_path="/lpai/volumes/ad-vla-vol-ga/chenantong/ckpts/CogVideoX-5b-I2V" +# checkpoint_path="../train/nuscenes-inpainting/scratch_nus_branch_nonedit/checkpoint-10000" +# meta_file_path="/lpai/volumes/ad-vla-vol-ga/chenantong/datasets/NuScenes/annotation/nuScenes_val_selected1k.json" +# condition_data_root="/lpai/dataset/nuscenes-val-render-mons/0-1-0/nuscenes_vistaval_balanced1k_stride1window3_nonedit" +# video_data_root="/lpai/dataset/nuscenes-imgs/0-1-0/nuscenes" + +# for gpu in 0 1 2 3 4 5 6 7; do +# CUDA_VISIBLE_DEVICES=$gpu python run_validation.py \ +# --checkpoint_path $checkpoint_path \ +# --pretrained_model_name_or_path $pretrained_model_name_or_path \ +# --meta_file_path $meta_file_path \ +# --condition_data_root $condition_data_root \ +# --video_data_root $video_data_root \ +# --output_dir validation_results/scratch_nus_branch_nonedit_10ksteps \ +# --height 480 \ +# --width 720 \ +# --max_num_frames 49 \ +# --mixed_precision bf16 \ +# --target_frames 25 \ +# --total_sections 8 \ +# --section $((gpu)) & +# done + +# custom traj +pretrained_model_name_or_path="THUDM/CogVideoX-5b-I2V" # or local path to downloaded CogVideoX-5b-I2V +checkpoint_path="../checkpoints/geodrive-branch" +meta_file_path="./custom_traj/custom_traj_meta.json" +condition_data_root="./custom_traj" +video_data_root="./custom_traj" + +for gpu in 0; do + CUDA_VISIBLE_DEVICES=$gpu python run_validation.py \ + --checkpoint_path $checkpoint_path \ + --pretrained_model_name_or_path $pretrained_model_name_or_path \ + --meta_file_path $meta_file_path \ + --condition_data_root $condition_data_root \ + --video_data_root $video_data_root \ + --output_dir validation_results/custom_traj \ + --height 480 \ + --width 720 \ + --conditioning_scale 1.0 \ + --max_num_frames 49 \ + --mixed_precision bf16 \ + --target_frames 49 \ + --total_sections 1 \ + --section 0 & +done + +# novel view +# pretrained_model_name_or_path="/lpai/volumes/ad-vla-vol-ga/chenantong/ckpts/CogVideoX-5b-I2V" +# checkpoint_path="../train/nuscenes-inpainting/2nodes/checkpoint-28000" +# # lora_path="../train/nuscenes-inpainting/resume10k_nuswaymo_loraonly/checkpoint-14000" +# meta_file_path="./custom_novelview/custom_traj_meta.json" +# condition_data_root="./custom_novelview" +# video_data_root="./custom_novelview" + +# for gpu in 3 4 5 6 7; do +# CUDA_VISIBLE_DEVICES=$gpu python run_validation.py \ +# --checkpoint_path $checkpoint_path \ +# --pretrained_model_name_or_path $pretrained_model_name_or_path \ +# --meta_file_path $meta_file_path \ +# --condition_data_root $condition_data_root \ +# --video_data_root $video_data_root \ +# --output_dir validation_results/2nodes_28ksteps/novelview \ +# --height 480 \ +# --width 720 \ +# --conditioning_scale 1.0 \ +# --max_num_frames 49 \ +# --mixed_precision bf16 \ +# --target_frames 25 \ +# --total_sections 5 \ +# --section $((gpu-3)) & +# done diff --git a/monst3r/README.md b/monst3r/README.md new file mode 100644 index 0000000..84c30b3 --- /dev/null +++ b/monst3r/README.md @@ -0,0 +1,173 @@ +# MonST3R: A Simple Approach for Estimating Geometry in the Presence of Motion + +**MonST3R** processes a dynamic video to produce a time-varying dynamic point cloud, along with per-frame camera poses and intrinsics, in a predominantly **feed-forward** manner. This representation then enables the efficient computation of downstream tasks, such as video depth estimation and dynamic/static scene segmentation. + +This repository is the official implementation of the paper: + +[**MonST3R: A Simple Approach for Estimating Geometry in the Presence of Motion**](https://monst3r-project.github.io/files/monst3r_paper.pdf) +[*Junyi Zhang*](https://junyi42.github.io/), +[*Charles Herrmann+*](https://scholar.google.com/citations?user=LQvi5XAAAAAJ), +[*Junhwa Hur*](https://hurjunhwa.github.io/), +[*Varun Jampani*](https://varunjampani.github.io/), +[*Trevor Darrell*](https://people.eecs.berkeley.edu/~trevor/), +[*Forrester Cole*](https://scholar.google.com/citations?user=xZRRr-IAAAAJ&hl), +[*Deqing Sun**](https://deqings.github.io/), +[*Ming-Hsuan Yang**](https://faculty.ucmerced.edu/mhyang/) +Arxiv, 2024. [**[Project Page]**](https://monst3r-project.github.io/) [**[Paper]**](https://monst3r-project.github.io/files/monst3r_paper.pdf) [**[Interactive Results🔥]**](https://monst3r-project.github.io/page1.html) + +[![Watch the video](assets/fig1_teaser.png)](https://monst3r-project.github.io/files/teaser_vid_v2_lowres.mp4) + +## TODO +- [x] Release model weights on [Google Drive](https://drive.google.com/file/d/1Z1jO_JmfZj0z3bgMvCwqfUhyZ1bIbc9E/view?usp=sharing) and [Hugging Face](https://huggingface.co/Junyi42/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt) (10/07) +- [x] Release inference code for global optimization (10/18) +- [x] Release 4D visualization code (10/18) +- [x] Release training code & dataset preparation (10/19) +- [x] Release evaluation code (10/20) +- [x] Memory efficient optimization v1 (12/25): a non-batchified version of the global optimization, slower but less memory usage +- [x] Real-time reconstruction mode (1/20): a fully feed-forward mode for real-time reconstruction ([sample results](https://monst3r-paper.github.io/page0.html)) +- [ ] Gradio Demo + +## Getting Started + +### Installation + +1. Clone MonST3R. +```bash +git clone --recursive https://github.com/junyi42/monst3r +cd monst3r +## if you have already cloned monst3r: +# git submodule update --init --recursive +``` + +2. Create the environment, here we show an example using conda. +```bash +conda create -n monst3r python=3.11 cmake=3.14.0 +conda activate monst3r +conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia # use the correct version of cuda for your system +pip install -r requirements.txt +# Optional: you can also install additional packages to: +# - training +# - evaluation on camera pose +# - dataset preparation +pip install -r requirements_optional.txt +``` + +3. Optional, install 4d visualization tool, `viser`. +```bash +pip install -e viser +``` + +4. Optional, compile the cuda kernels for RoPE (as in CroCo v2). +```bash +# DUST3R relies on RoPE positional embeddings for which you can compile some cuda kernels for faster runtime. +cd croco/models/curope/ +python setup.py build_ext --inplace +cd ../../../ +``` + +### Download Checkpoints + +We currently provide fine-tuned model weights for MonST3R, which can be downloaded on [Google Drive](https://drive.google.com/file/d/1Z1jO_JmfZj0z3bgMvCwqfUhyZ1bIbc9E/view?usp=sharing) or via [Hugging Face](https://huggingface.co/Junyi42/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt). + + +To download the weights of MonST3R and optical flow models, run the following commands: +```bash +# download the weights +cd data +bash download_ckpt.sh +cd .. +``` + +### Inference + +To run the inference code, you can use the following command: +```bash +python demo.py # launch GUI, input can be a folder or a video +# use memory efficient optimization: --not_batchify +``` + +The results will be saved in the `demo_tmp/{Sequence Name}` (by default is `demo_tmp/NULL`) folder for future visualization. + +You can also run the inference code in a non-interactive mode: +```bash +python demo.py --input demo_data/lady-running --output_dir demo_tmp --seq_name lady-running +# use video as input: --input demo_data/lady-running.mp4 --num_frames 65 +# (update 12/15) use memory efficient optimization: --not_batchify +# (update 1/20) use real-time mode: --real_time +``` + +> Currently, it takes about 33G VRAM to run the inference code on a 16:9 video of 65 frames. Use less frames or disable the `flow_loss` could reduce the memory usage. We are **welcome to any PRs** to improve the memory efficiency (one reasonable way is to implement window-wise optimzation in `optimizer.py`). + +> **Update (12/15):** Using the non-batchified version of the global optimization, the memory usage has been reduced. Now, it only requires ~23G VRAM to inference on a 65-frames 16:9 video. It is now slower, but we are **welcome to any PRs** to improve the memory-speed trade-off, e.g., using a window-wise optimization in function `forward_non_batchify()` instead of edge-wise. + +> **Update (1/20):** We have added a fully feed-forward mode for the inference code, which can run in real-time. The results are worse than the global optimization mode, and it only applies to cases where the camera motion is small. More details will be updated soon. + +### Visualization + +To visualize the interactive 4D results, you can use the following command: +```bash +python viser/visualizer_monst3r.py --data demo_tmp/lady-running +# to remove the floaters of foreground: --init_conf --fg_conf_thre 1.0 (thre can be adjusted) + +# (update 1/20) for results generated by real-time mode, please (update viser and) using the following command: +python viser/visualizer_monst3r_realtime.py --data demo_tmp/lady-running +``` + +## Evaluation + +We provide here an example of joint dense reconstruction and camera pose estimation on the **DAVIS** dataset. + +First, download the dataset: +```bash +cd data; python download_davis.py; cd .. +``` + +Then, run the evaluation script: +```bash +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=29604 launch.py --mode=eval_pose \ + --pretrained="checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth" \ + --eval_dataset=davis --output_dir="results/davis_joint" + # To use the ground truth dynamic mask for davis, add: --use_gt_mask +``` + +You could then use the `viser` to visualize the results: +```bash +python viser/visualizer_monst3r.py --data results/davis_joint/bear +# if the dynamic mask is noisy, one could visualize per-frame pointcloud by adding: --no_mask +``` + +#### For the complete scripts to evaluate the camera pose / video depth / single-frame depth estimation on the **Sintel**, **Bonn**, **KITTI**, **NYU-v2**, **TUM-dynamics**, **ScanNet**, and **DAVIS** datasets. Please refer to the [evaluation_script.md](data/evaluation_script.md) for more details. + + +## Training + +Please refer to the [prepare_training.md](data/prepare_training.md) for preparing the pretrained models and training/testing datasets. + +Then, you can train the model using the following command: +```bash +CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 --master_port=29604 launch.py --mode=train \ + --train_dataset="10_000 @ PointOdysseyDUSt3R(dset='train', z_far=80, dataset_location='data/point_odyssey', S=2, aug_crop=16, resolution=[(512, 288), (512, 384), (512, 336)], transform=ColorJitter, strides=[1,2,3,4,5,6,7,8,9], dist_type='linear_1_2', aug_focal=0.9)+ 5_000 @ TarTanAirDUSt3R(dset='Hard', z_far=80, dataset_location='data/tartanair', S=2, aug_crop=16, resolution=[(512, 288), (512, 384), (512, 336)], transform=ColorJitter, strides=[1,2,3,4,5,6,7,8,9], dist_type='linear_1_2', aug_focal=0.9)+ 1_000 @ SpringDUSt3R(dset='train', z_far=80, dataset_location='data/spring', S=2, aug_crop=16, resolution=[(512, 288), (512, 384), (512, 336)], transform=ColorJitter, strides=[1,2,3,4,5,6,7,8,9], dist_type='linear_1_2', aug_focal=0.9)+ 4_000 @ Waymo(ROOT='data/waymo_processed', pairs_npz_name='waymo_pairs_video.npz', aug_crop=16, resolution=[(512, 288), (512, 384), (512, 336)], transform=ColorJitter, aug_focal=0.9)" \ + --test_dataset="1000 @ PointOdysseyDUSt3R(dset='test', z_far=80, dataset_location='data/point_odyssey', S=2, strides=[1,2,3,4,5,6,7,8,9], resolution=[(512, 288)], seed=777)+ 1000 @ SintelDUSt3R(dset='final', z_far=80, S=2, strides=[1,2,3,4,5,6,7,8,9], resolution=[(512, 224)], seed=777)" \ + --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \ + --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \ + --pretrained="checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth" \ + --lr=0.00005 --min_lr=1e-06 --warmup_epochs=3 --epochs=50 --batch_size=4 --accum_iter=4 \ + --save_freq=3 --keep_freq=5 --eval_freq=1 \ + --output_dir="results/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt" +``` + +## Citation + +If you find our work useful, please cite: + +```bibtex +@article{zhang2024monst3r, + author = {Zhang, Junyi and Herrmann, Charles and Hur, Junhwa and Jampani, Varun and Darrell, Trevor and Cole, Forrester and Sun, Deqing and Yang, Ming-Hsuan}, + title = {MonST3R: A Simple Approach for Estimating Geometry in the Presence of Motion}, + journal = {arXiv preprint arxiv:2410.03825}, + year = {2024} +} +``` + +## Acknowledgements +Our code is based on [DUSt3R](https://github.com/naver/dust3r) and [CasualSAM](https://github.com/ztzhang/casualSAM), our camera pose estimation evaluation script is based on [LEAP-VO](https://github.com/chiaki530/leapvo), and our visualization code is based on [Viser](https://github.com/nerfstudio-project/viser). We thank the authors for their excellent work! diff --git a/monst3r/croco/.gitignore b/monst3r/croco/.gitignore new file mode 100644 index 0000000..b6e4761 --- /dev/null +++ b/monst3r/croco/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/monst3r/croco/LICENSE b/monst3r/croco/LICENSE new file mode 100644 index 0000000..d9b84b1 --- /dev/null +++ b/monst3r/croco/LICENSE @@ -0,0 +1,52 @@ +CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license. + +A summary of the CC BY-NC-SA 4.0 license is located here: + https://creativecommons.org/licenses/by-nc-sa/4.0/ + +The CC BY-NC-SA 4.0 license is located here: + https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode + + +SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py + +*************************** + +NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py + +This software is being redistributed in a modifiled form. The original form is available here: + +https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +This software in this file incorporates parts of the following software available here: + +Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py +available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE + +MoCo v3: https://github.com/facebookresearch/moco-v3 +available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE + +DeiT: https://github.com/facebookresearch/deit +available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE + + +ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW: + +https://github.com/facebookresearch/mae/blob/main/LICENSE + +Attribution-NonCommercial 4.0 International + +*************************** + +NOTICE WITH RESPECT TO THE FILE: models/blocks.py + +This software is being redistributed in a modifiled form. The original form is available here: + +https://github.com/rwightman/pytorch-image-models + +ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW: + +https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE + +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ \ No newline at end of file diff --git a/monst3r/croco/NOTICE b/monst3r/croco/NOTICE new file mode 100644 index 0000000..d51bb36 --- /dev/null +++ b/monst3r/croco/NOTICE @@ -0,0 +1,21 @@ +CroCo +Copyright 2022-present NAVER Corp. + +This project contains subcomponents with separate copyright notices and license terms. +Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. + +==== + +facebookresearch/mae +https://github.com/facebookresearch/mae + +Attribution-NonCommercial 4.0 International + +==== + +rwightman/pytorch-image-models +https://github.com/rwightman/pytorch-image-models + +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ \ No newline at end of file diff --git a/monst3r/croco/README.MD b/monst3r/croco/README.MD new file mode 100644 index 0000000..38e33b0 --- /dev/null +++ b/monst3r/croco/README.MD @@ -0,0 +1,124 @@ +# CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow + +[[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)] + +This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), refered to as CroCo v2: + +![image](assets/arch.jpg) + +```bibtex +@inproceedings{croco, + title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}}, + author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}}, + booktitle={{NeurIPS}}, + year={2022} +} + +@inproceedings{croco_v2, + title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}}, + author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me}, + booktitle={ICCV}, + year={2023} +} +``` + +## License + +The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information. +Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License. +Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license. + +## Preparation + +1. Install dependencies on a machine with a NVidia GPU using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can ignore the line installing it and use a more recent python version. + +```bash +conda create -n croco python=3.7 cmake=3.14.0 +conda activate croco +conda install habitat-sim headless -c conda-forge -c aihabitat +conda install pytorch torchvision -c pytorch +conda install notebook ipykernel matplotlib +conda install ipywidgets widgetsnbextension +conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation + +``` + +2. Compile cuda kernels for RoPE + +CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels. +```bash +cd models/curope/ +python setup.py build_ext --inplace +cd ../../ +``` + +This can be a bit long as we compile for all cuda architectures, feel free to update L9 of `models/curope/setup.py` to compile for specific architectures only. +You might also need to set the environment `CUDA_HOME` in case you use a custom cuda installation. + +In case you cannot provide, we also provide a slow pytorch version, which will be automatically loaded. + +3. Download pre-trained model + +We provide several pre-trained models: + +| modelname | pre-training data | pos. embed. | Encoder | Decoder | +|------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------| +| [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small | +| [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small | +| [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base | +| [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base | + +To download a specific model, i.e., the first one (`CroCo.pth`) +```bash +mkdir -p pretrained_models/ +wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/ +``` + +## Reconstruction example + +Simply run after downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or update the corresponding line in `demo.py`) +```bash +python demo.py +``` + +## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator + +First download the test scene from Habitat: +```bash +python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/ +``` + +Then, run the Notebook demo `interactive_demo.ipynb`. + +In this demo, you should be able to sample a random reference viewpoint from an [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change viewpoint and select a masked target view to reconstruct using CroCo. +![croco_interactive_demo](https://user-images.githubusercontent.com/1822210/200516576-7937bc6a-55f8-49ed-8618-3ddf89433ea4.jpg) + +## Pre-training + +### CroCo + +To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command: +``` +torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/ +``` + +Our CroCo pre-training was launched on a single server with 4 GPUs. +It should take around 10 days with A100 or 15 days with V100 to do the 400 pre-training epochs, but decent performances are obtained earlier in training. +Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experimented if it is valid in our case. +The first run can take a few minutes to start, to parse all available pre-training pairs. + +### CroCo v2 + +For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD). +Then, run the following command for the largest model (ViT-L encoder, Base decoder): +``` +torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/ +``` + +Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per gpu in all cases. +The largest model should take around 12 days on A100. +Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experimented if it is valid in our case. + +## Stereo matching and Optical flow downstream tasks + +For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD). diff --git a/monst3r/croco/croco-stereo-flow-demo.ipynb b/monst3r/croco/croco-stereo-flow-demo.ipynb new file mode 100644 index 0000000..2b00a76 --- /dev/null +++ b/monst3r/croco/croco-stereo-flow-demo.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9bca0f41", + "metadata": {}, + "source": [ + "# Simple inference example with CroCo-Stereo or CroCo-Flow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80653ef7", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n", + "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)." + ] + }, + { + "cell_type": "markdown", + "id": "4f033862", + "metadata": {}, + "source": [ + "First download the model(s) of your choice by running\n", + "```\n", + "bash stereoflow/download_model.sh crocostereo.pth\n", + "bash stereoflow/download_model.sh crocoflow.pth\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fb2e392", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n", + "device = torch.device('cuda:0' if use_gpu else 'cpu')\n", + "import matplotlib.pylab as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0e25d77", + "metadata": {}, + "outputs": [], + "source": [ + "from stereoflow.test import _load_model_and_criterion\n", + "from stereoflow.engine import tiled_pred\n", + "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n", + "from stereoflow.datasets_flow import flowToColor\n", + "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower" + ] + }, + { + "cell_type": "markdown", + "id": "86a921f5", + "metadata": {}, + "source": [ + "### CroCo-Stereo example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64e483cb", + "metadata": {}, + "outputs": [], + "source": [ + "image1 = np.asarray(Image.open(''))\n", + "image2 = np.asarray(Image.open(''))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0d04303", + "metadata": {}, + "outputs": [], + "source": [ + "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47dc14b5", + "metadata": {}, + "outputs": [], + "source": [ + "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n", + "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n", + "with torch.inference_mode():\n", + " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n", + "pred = pred.squeeze(0).squeeze(0).cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "583b9f16", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(vis_disparity(pred))\n", + "plt.axis('off')" + ] + }, + { + "cell_type": "markdown", + "id": "d2df5d70", + "metadata": {}, + "source": [ + "### CroCo-Flow example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ee257a7", + "metadata": {}, + "outputs": [], + "source": [ + "image1 = np.asarray(Image.open(''))\n", + "image2 = np.asarray(Image.open(''))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5edccf0", + "metadata": {}, + "outputs": [], + "source": [ + "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b19692c3", + "metadata": {}, + "outputs": [], + "source": [ + "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n", + "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n", + "with torch.inference_mode():\n", + " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n", + "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26f79db3", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(flowToColor(pred))\n", + "plt.axis('off')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/monst3r/croco/datasets/__init__.py b/monst3r/croco/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/monst3r/croco/datasets/crops/README.MD b/monst3r/croco/datasets/crops/README.MD new file mode 100644 index 0000000..47ddabe --- /dev/null +++ b/monst3r/croco/datasets/crops/README.MD @@ -0,0 +1,104 @@ +## Generation of crops from the real datasets + +The instructions below allow to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL. + +### Download the metadata of the crops to generate + +First, download the metadata and put them in `./data/`: +``` +mkdir -p data +cd data/ +wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip +unzip crop_metadata.zip +rm crop_metadata.zip +cd .. +``` + +### Prepare the original datasets + +Second, download the original datasets in `./data/original_datasets/`. +``` +mkdir -p data/original_datasets +``` + +##### ARKitScenes + +Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`. +The resulting file structure should be like: +``` +./data/original_datasets/ARKitScenes/ +└───Training + └───40753679 + │ │ ultrawide + │ │ ... + └───40753686 + │ + ... +``` + +##### MegaDepth + +Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`. +The resulting file structure should be like: + +``` +./data/original_datasets/MegaDepth/ +└───0000 +│ └───images +│ │ │ 1000557903_87fa96b8a4_o.jpg +│ │ └ ... +│ └─── ... +└───0001 +│ │ +│ └ ... +└─── ... +``` + +##### 3DStreetView + +Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`. +The resulting file structure should be like: + +``` +./data/original_datasets/3DStreetView/ +└───dataset_aligned +│ └───0002 +│ │ │ 0000002_0000001_0000002_0000001.jpg +│ │ └ ... +│ └─── ... +└───dataset_unaligned +│ └───0003 +│ │ │ 0000003_0000001_0000002_0000001.jpg +│ │ └ ... +│ └─── ... +``` + +##### IndoorVL + +Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture). + +``` +pip install kapture +mkdir -p ./data/original_datasets/IndoorVL +cd ./data/original_datasets/IndoorVL +kapture_download_dataset.py update +kapture_download_dataset.py install "HyundaiDepartmentStore_*" +kapture_download_dataset.py install "GangnamStation_*" +cd - +``` + +### Extract the crops + +Now, extract the crops for each of the dataset: +``` +for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL; +do + python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500; +done +``` + +##### Note for IndoorVL + +Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper. +To account for it in terms of number of pre-training iterations, the pre-training command in this repository uses 125 training epochs including 12 warm-up epochs and learning rate cosine schedule of 250, instead of 100, 10 and 200 respectively. +The impact on the performance is negligible. diff --git a/monst3r/croco/datasets/crops/extract_crops_from_images.py b/monst3r/croco/datasets/crops/extract_crops_from_images.py new file mode 100644 index 0000000..eb66a04 --- /dev/null +++ b/monst3r/croco/datasets/crops/extract_crops_from_images.py @@ -0,0 +1,159 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Extracting crops for pre-training +# -------------------------------------------------------- + +import os +import argparse +from tqdm import tqdm +from PIL import Image +import functools +from multiprocessing import Pool +import math + + +def arg_parser(): + parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list') + + parser.add_argument('--crops', type=str, required=True, help='crop file') + parser.add_argument('--root-dir', type=str, required=True, help='root directory') + parser.add_argument('--output-dir', type=str, required=True, help='output directory') + parser.add_argument('--imsize', type=int, default=256, help='size of the crops') + parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads') + parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories') + parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir') + return parser + + +def main(args): + listing_path = os.path.join(args.output_dir, 'listing.txt') + + print(f'Loading list of crops ... ({args.nthread} threads)') + crops, num_crops_to_generate = load_crop_file(args.crops) + + print(f'Preparing jobs ({len(crops)} candidate image pairs)...') + num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels) + num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels)) + + jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir) + del crops + + os.makedirs(args.output_dir, exist_ok=True) + mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map + call = functools.partial(save_image_crops, args) + + print(f"Generating cropped images to {args.output_dir} ...") + with open(listing_path, 'w') as listing: + listing.write('# pair_path\n') + for results in tqdm(mmap(call, jobs), total=len(jobs)): + for path in results: + listing.write(f'{path}\n') + print('Finished writing listing to', listing_path) + + +def load_crop_file(path): + data = open(path).read().splitlines() + pairs = [] + num_crops_to_generate = 0 + for line in tqdm(data): + if line.startswith('#'): + continue + line = line.split(', ') + if len(line) < 8: + img1, img2, rotation = line + pairs.append((img1, img2, int(rotation), [])) + else: + l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line) + rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2) + pairs[-1][-1].append((rect1, rect2)) + num_crops_to_generate += 1 + return pairs, num_crops_to_generate + + +def prepare_jobs(pairs, num_levels, num_pairs_in_dir): + jobs = [] + powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))] + + def get_path(idx): + idx_array = [] + d = idx + for level in range(num_levels - 1): + idx_array.append(idx // powers[level]) + idx = idx % powers[level] + idx_array.append(d) + return '/'.join(map(lambda x: hex(x)[2:], idx_array)) + + idx = 0 + for pair_data in tqdm(pairs): + img1, img2, rotation, crops = pair_data + if -60 <= rotation and rotation <= 60: + rotation = 0 # most likely not a true rotation + paths = [get_path(idx + k) for k in range(len(crops))] + idx += len(crops) + jobs.append(((img1, img2), rotation, crops, paths)) + return jobs + + +def load_image(path): + try: + return Image.open(path).convert('RGB') + except Exception as e: + print('skipping', path, e) + raise OSError() + + +def save_image_crops(args, data): + # load images + img_pair, rot, crops, paths = data + try: + img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair] + except OSError as e: + return [] + + def area(sz): + return sz[0] * sz[1] + + tgt_size = (args.imsize, args.imsize) + + def prepare_crop(img, rect, rot=0): + # actual crop + img = img.crop(rect) + + # resize to desired size + interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC + img = img.resize(tgt_size, resample=interp) + + # rotate the image + rot90 = (round(rot/90) % 4) * 90 + if rot90 == 90: + img = img.transpose(Image.Transpose.ROTATE_90) + elif rot90 == 180: + img = img.transpose(Image.Transpose.ROTATE_180) + elif rot90 == 270: + img = img.transpose(Image.Transpose.ROTATE_270) + return img + + results = [] + for (rect1, rect2), path in zip(crops, paths): + crop1 = prepare_crop(img1, rect1) + crop2 = prepare_crop(img2, rect2, rot) + + fullpath1 = os.path.join(args.output_dir, path+'_1.jpg') + fullpath2 = os.path.join(args.output_dir, path+'_2.jpg') + os.makedirs(os.path.dirname(fullpath1), exist_ok=True) + + assert not os.path.isfile(fullpath1), fullpath1 + assert not os.path.isfile(fullpath2), fullpath2 + crop1.save(fullpath1) + crop2.save(fullpath2) + results.append(path) + + return results + + +if __name__ == '__main__': + args = arg_parser().parse_args() + main(args) + diff --git a/monst3r/croco/datasets/habitat_sim/README.MD b/monst3r/croco/datasets/habitat_sim/README.MD new file mode 100644 index 0000000..a505781 --- /dev/null +++ b/monst3r/croco/datasets/habitat_sim/README.MD @@ -0,0 +1,76 @@ +## Generation of synthetic image pairs using Habitat-Sim + +These instructions allow to generate pre-training pairs from the Habitat simulator. +As we did not save metadata of the pairs used in the original paper, they are not strictly the same, but these data use the same setting and are equivalent. + +### Download Habitat-Sim scenes +Download Habitat-Sim scenes: +- Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md +- We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets. +- Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or update manually paths in `paths.py`. +``` +./data/ +└──habitat-sim-data/ + └──scene_datasets/ + ├──hm3d/ + ├──gibson/ + ├──habitat-test-scenes/ + ├──replica_cad_baked_lighting/ + ├──replica_cad/ + ├──ReplicaDataset/ + └──scannet/ +``` + +### Image pairs generation +We provide metadata to generate reproducible images pairs for pretraining and validation. +Experiments described in the paper used similar data, but whose generation was not reproducible at the time. + +Specifications: +- 256x256 resolution images, with 60 degrees field of view . +- Up to 1000 image pairs per scene. +- Number of scenes considered/number of images pairs per dataset: + - Scannet: 1097 scenes / 985 209 pairs + - HM3D: + - hm3d/train: 800 / 800k pairs + - hm3d/val: 100 scenes / 100k pairs + - hm3d/minival: 10 scenes / 10k pairs + - habitat-test-scenes: 3 scenes / 3k pairs + - replica_cad_baked_lighting: 13 scenes / 13k pairs + +- Scenes from hm3d/val and hm3d/minival pairs were not used for the pre-training but kept for validation purposes. + +Download metadata and extract it: +```bash +mkdir -p data/habitat_release_metadata/ +cd data/habitat_release_metadata/ +wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz +tar -xvf multiview_habitat_metadata.tar.gz +cd ../.. +# Location of the metadata +METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata" +``` + +Generate image pairs from metadata: +- The following command will print a list of commandlines to generate image pairs for each scene: +```bash +# Target output directory +PAIRS_DATASET_DIR="./data/habitat_release/" +python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR +``` +- One can launch multiple of such commands in parallel e.g. using GNU Parallel: +```bash +python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16 +``` + +## Metadata generation + +Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible: +```bash +# Print commandlines to generate image pairs from the different scenes available. +PAIRS_DATASET_DIR=MY_CUSTOM_PATH +python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR + +# Once a dataset is generated, pack metadata files for reproducibility. +METADATA_DIR=MY_CUSTON_PATH +python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR +``` diff --git a/monst3r/croco/datasets/habitat_sim/__init__.py b/monst3r/croco/datasets/habitat_sim/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/monst3r/croco/datasets/habitat_sim/generate_from_metadata.py b/monst3r/croco/datasets/habitat_sim/generate_from_metadata.py new file mode 100644 index 0000000..fbe0d39 --- /dev/null +++ b/monst3r/croco/datasets/habitat_sim/generate_from_metadata.py @@ -0,0 +1,92 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +""" +Script to generate image pairs for a given scene reproducing poses provided in a metadata file. +""" +import os +from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator +from datasets.habitat_sim.paths import SCENES_DATASET +import argparse +import quaternion +import PIL.Image +import cv2 +import json +from tqdm import tqdm + +def generate_multiview_images_from_metadata(metadata_filename, + output_dir, + overload_params = dict(), + scene_datasets_paths=None, + exist_ok=False): + """ + Generate images from a metadata file for reproducibility purposes. + """ + # Reorder paths by decreasing label length, to avoid collisions when testing if a string by such label + if scene_datasets_paths is not None: + scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key= lambda x: len(x[0]), reverse=True)) + + with open(metadata_filename, 'r') as f: + input_metadata = json.load(f) + metadata = dict() + for key, value in input_metadata.items(): + # Optionally replace some paths + if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": + if scene_datasets_paths is not None: + for dataset_label, dataset_path in scene_datasets_paths.items(): + if value.startswith(dataset_label): + value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label))) + break + metadata[key] = value + + # Overload some parameters + for key, value in overload_params.items(): + metadata[key] = value + + generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))]) + generate_depth = metadata["generate_depth"] + + os.makedirs(output_dir, exist_ok=exist_ok) + + generator = MultiviewHabitatSimGenerator(**generation_entries) + + # Generate views + for idx_label, data in tqdm(metadata['multiviews'].items()): + positions = data["positions"] + orientations = data["orientations"] + n = len(positions) + for oidx in range(n): + observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx])) + observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1 + # Color image saved using PIL + img = PIL.Image.fromarray(observation['color'][:,:,:3]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg") + img.save(filename) + if generate_depth: + # Depth image as EXR file + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr") + cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + # Camera parameters + camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json") + with open(filename, "w") as f: + json.dump(camera_params, f) + # Save metadata + with open(os.path.join(output_dir, "metadata.json"), "w") as f: + json.dump(metadata, f) + + generator.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_filename", required=True) + parser.add_argument("--output_dir", required=True) + args = parser.parse_args() + + generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename, + output_dir=args.output_dir, + scene_datasets_paths=SCENES_DATASET, + overload_params=dict(), + exist_ok=True) + + \ No newline at end of file diff --git a/monst3r/croco/datasets/habitat_sim/generate_from_metadata_files.py b/monst3r/croco/datasets/habitat_sim/generate_from_metadata_files.py new file mode 100644 index 0000000..962ef84 --- /dev/null +++ b/monst3r/croco/datasets/habitat_sim/generate_from_metadata_files.py @@ -0,0 +1,27 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +""" +Script generating commandlines to generate image pairs from metadata files. +""" +import os +import glob +from tqdm import tqdm +import argparse + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", required=True) + parser.add_argument("--output_dir", required=True) + parser.add_argument("--prefix", default="", help="Commanline prefix, useful e.g. to setup environment.") + args = parser.parse_args() + + input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True) + + for metadata_filename in tqdm(input_metadata_filenames): + output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir)) + # Do not process the scene if the metadata file already exists + if os.path.exists(os.path.join(output_dir, "metadata.json")): + continue + commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}" + print(commandline) diff --git a/monst3r/croco/datasets/habitat_sim/generate_multiview_images.py b/monst3r/croco/datasets/habitat_sim/generate_multiview_images.py new file mode 100644 index 0000000..421d49a --- /dev/null +++ b/monst3r/croco/datasets/habitat_sim/generate_multiview_images.py @@ -0,0 +1,177 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import os +from tqdm import tqdm +import argparse +import PIL.Image +import numpy as np +import json +from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator, NoNaviguableSpaceError +from datasets.habitat_sim.paths import list_scenes_available +import cv2 +import quaternion +import shutil + +def generate_multiview_images_for_scene(scene_dataset_config_file, + scene, + navmesh, + output_dir, + views_count, + size, + exist_ok=False, + generate_depth=False, + **kwargs): + """ + Generate tuples of overlapping views for a given scene. + generate_depth: generate depth images and camera parameters. + """ + if os.path.exists(output_dir) and not exist_ok: + print(f"Scene {scene}: data already generated. Ignoring generation.") + return + try: + print(f"Scene {scene}: {size} multiview acquisitions to generate...") + os.makedirs(output_dir, exist_ok=exist_ok) + + metadata_filename = os.path.join(output_dir, "metadata.json") + + metadata_template = dict(scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + views_count=views_count, + size=size, + generate_depth=generate_depth, + **kwargs) + metadata_template["multiviews"] = dict() + + if os.path.exists(metadata_filename): + print("Metadata file already exists:", metadata_filename) + print("Loading already generated metadata file...") + with open(metadata_filename, "r") as f: + metadata = json.load(f) + + for key in metadata_template.keys(): + if key != "multiviews": + assert metadata_template[key] == metadata[key], f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}." + else: + print("No temporary file found. Starting generation from scratch...") + metadata = metadata_template + + starting_id = len(metadata["multiviews"]) + print(f"Starting generation from index {starting_id}/{size}...") + if starting_id >= size: + print("Generation already done.") + return + + generator = MultiviewHabitatSimGenerator(scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + views_count = views_count, + size = size, + **kwargs) + + for idx in tqdm(range(starting_id, size)): + # Generate / re-generate the observations + try: + data = generator[idx] + observations = data["observations"] + positions = data["positions"] + orientations = data["orientations"] + + idx_label = f"{idx:08}" + for oidx, observation in enumerate(observations): + observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1 + # Color image saved using PIL + img = PIL.Image.fromarray(observation['color'][:,:,:3]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg") + img.save(filename) + if generate_depth: + # Depth image as EXR file + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr") + cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + # Camera parameters + camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json") + with open(filename, "w") as f: + json.dump(camera_params, f) + metadata["multiviews"][idx_label] = {"positions": positions.tolist(), + "orientations": orientations.tolist(), + "covisibility_ratios": data["covisibility_ratios"].tolist(), + "valid_fractions": data["valid_fractions"].tolist(), + "pairwise_visibility_ratios": data["pairwise_visibility_ratios"].tolist()} + except RecursionError: + print("Recursion error: unable to sample observations for this scene. We will stop there.") + break + + # Regularly save a temporary metadata file, in case we need to restart the generation + if idx % 10 == 0: + with open(metadata_filename, "w") as f: + json.dump(metadata, f) + + # Save metadata + with open(metadata_filename, "w") as f: + json.dump(metadata, f) + + generator.close() + except NoNaviguableSpaceError: + pass + +def create_commandline(scene_data, generate_depth, exist_ok=False): + """ + Create a commandline string to generate a scene. + """ + def my_formatting(val): + if val is None or val == "": + return '""' + else: + return val + commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)} + --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)} + --navmesh {my_formatting(scene_data.navmesh)} + --output_dir {my_formatting(scene_data.output_dir)} + --generate_depth {int(generate_depth)} + --exist_ok {int(exist_ok)} + """ + commandline = " ".join(commandline.split()) + return commandline + +if __name__ == "__main__": + os.umask(2) + + parser = argparse.ArgumentParser(description="""Example of use -- listing commands to generate data for scenes available: + > python datasets/habitat_sim/generate_multiview_habitat_images.py --list_commands + """) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--list_commands", action='store_true', help="list commandlines to run if true") + parser.add_argument("--scene", type=str, default="") + parser.add_argument("--scene_dataset_config_file", type=str, default="") + parser.add_argument("--navmesh", type=str, default="") + + parser.add_argument("--generate_depth", type=int, default=1) + parser.add_argument("--exist_ok", type=int, default=0) + + kwargs = dict(resolution=(256,256), hfov=60, views_count = 2, size=1000) + + args = parser.parse_args() + generate_depth=bool(args.generate_depth) + exist_ok = bool(args.exist_ok) + + if args.list_commands: + # Listing scenes available... + scenes_data = list_scenes_available(base_output_dir=args.output_dir) + + for scene_data in scenes_data: + print(create_commandline(scene_data, generate_depth=generate_depth, exist_ok=exist_ok)) + else: + if args.scene == "" or args.output_dir == "": + print("Missing scene or output dir argument!") + print(parser.format_help()) + else: + generate_multiview_images_for_scene(scene=args.scene, + scene_dataset_config_file = args.scene_dataset_config_file, + navmesh = args.navmesh, + output_dir = args.output_dir, + exist_ok=exist_ok, + generate_depth=generate_depth, + **kwargs) \ No newline at end of file diff --git a/monst3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py b/monst3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py new file mode 100644 index 0000000..91e5f92 --- /dev/null +++ b/monst3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py @@ -0,0 +1,390 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import os +import numpy as np +import quaternion +import habitat_sim +import json +from sklearn.neighbors import NearestNeighbors +import cv2 + +# OpenCV to habitat camera convention transformation +R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0) +R_HABITAT2OPENCV = R_OPENCV2HABITAT.T +DEG2RAD = np.pi / 180 + +def compute_camera_intrinsics(height, width, hfov): + f = width/2 / np.tan(hfov/2 * np.pi/180) + cu, cv = width/2, height/2 + return f, cu, cv + +def compute_camera_pose_opencv_convention(camera_position, camera_orientation): + R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT + t_cam2world = np.asarray(camera_position) + return R_cam2world, t_cam2world + +def compute_pointmap(depthmap, hfov): + """ Compute a HxWx3 pointmap in camera frame from a HxW depth map.""" + height, width = depthmap.shape + f, cu, cv = compute_camera_intrinsics(height, width, hfov) + # Cast depth map to point + z_cam = depthmap + u, v = np.meshgrid(range(width), range(height)) + x_cam = (u - cu) / f * z_cam + y_cam = (v - cv) / f * z_cam + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1) + return X_cam + +def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation): + """Return a 3D point cloud corresponding to valid pixels of the depth map""" + R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_position, camera_rotation) + + X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov) + valid_mask = (X_cam[:,:,2] != 0.0) + + X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()] + X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3) + return X_world + +def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False): + """ + Compute 'overlapping' metrics based on a distance threshold between two point clouds. + """ + nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud2) + distances, indices = nbrs.kneighbors(pointcloud1) + intersection1 = np.count_nonzero(distances.flatten() < distance_threshold) + + data = {"intersection1": intersection1, + "size1": len(pointcloud1)} + if compute_symmetric: + nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud1) + distances, indices = nbrs.kneighbors(pointcloud2) + intersection2 = np.count_nonzero(distances.flatten() < distance_threshold) + data["intersection2"] = intersection2 + data["size2"] = len(pointcloud2) + + return data + +def _append_camera_parameters(observation, hfov, camera_location, camera_rotation): + """ + Add camera parameters to the observation dictionnary produced by Habitat-Sim + In-place modifications. + """ + R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_location, camera_rotation) + height, width = observation['depth'].shape + f, cu, cv = compute_camera_intrinsics(height, width, hfov) + K = np.asarray([[f, 0, cu], + [0, f, cv], + [0, 0, 1.0]]) + observation["camera_intrinsics"] = K + observation["t_cam2world"] = t_cam2world + observation["R_cam2world"] = R_cam2world + +def look_at(eye, center, up, return_cam2world=True): + """ + Return camera pose looking at a given center point. + Analogous of gluLookAt function, using OpenCV camera convention. + """ + z = center - eye + z /= np.linalg.norm(z, axis=-1, keepdims=True) + y = -up + y = y - np.sum(y * z, axis=-1, keepdims=True) * z + y /= np.linalg.norm(y, axis=-1, keepdims=True) + x = np.cross(y, z, axis=-1) + + if return_cam2world: + R = np.stack((x, y, z), axis=-1) + t = eye + else: + # World to camera transformation + # Transposed matrix + R = np.stack((x, y, z), axis=-2) + t = - np.einsum('...ij, ...j', R, eye) + return R, t + +def look_at_for_habitat(eye, center, up, return_cam2world=True): + R, t = look_at(eye, center, up) + orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T) + return orientation, t + +def generate_orientation_noise(pan_range, tilt_range, roll_range): + return (quaternion.from_rotation_vector(np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP) + * quaternion.from_rotation_vector(np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT) + * quaternion.from_rotation_vector(np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT)) + + +class NoNaviguableSpaceError(RuntimeError): + def __init__(self, *args): + super().__init__(*args) + +class MultiviewHabitatSimGenerator: + def __init__(self, + scene, + navmesh, + scene_dataset_config_file, + resolution = (240, 320), + views_count=2, + hfov = 60, + gpu_id = 0, + size = 10000, + minimum_covisibility = 0.5, + transform = None): + self.scene = scene + self.navmesh = navmesh + self.scene_dataset_config_file = scene_dataset_config_file + self.resolution = resolution + self.views_count = views_count + assert(self.views_count >= 1) + self.hfov = hfov + self.gpu_id = gpu_id + self.size = size + self.transform = transform + + # Noise added to camera orientation + self.pan_range = (-3, 3) + self.tilt_range = (-10, 10) + self.roll_range = (-5, 5) + + # Height range to sample cameras + self.height_range = (1.2, 1.8) + + # Random steps between the camera views + self.random_steps_count = 5 + self.random_step_variance = 2.0 + + # Minimum fraction of the scene which should be valid (well defined depth) + self.minimum_valid_fraction = 0.7 + + # Distance threshold to see to select pairs + self.distance_threshold = 0.05 + # Minimum IoU of a view point cloud with respect to the reference view to be kept. + self.minimum_covisibility = minimum_covisibility + + # Maximum number of retries. + self.max_attempts_count = 100 + + self.seed = None + self._lazy_initialization() + + def _lazy_initialization(self): + # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly + if self.seed == None: + # Re-seed numpy generator + np.random.seed() + self.seed = np.random.randint(2**32-1) + sim_cfg = habitat_sim.SimulatorConfiguration() + sim_cfg.scene_id = self.scene + if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "": + sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file + sim_cfg.random_seed = self.seed + sim_cfg.load_semantic_mesh = False + sim_cfg.gpu_device_id = self.gpu_id + + depth_sensor_spec = habitat_sim.CameraSensorSpec() + depth_sensor_spec.uuid = "depth" + depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH + depth_sensor_spec.resolution = self.resolution + depth_sensor_spec.hfov = self.hfov + depth_sensor_spec.position = [0.0, 0.0, 0] + depth_sensor_spec.orientation + + rgb_sensor_spec = habitat_sim.CameraSensorSpec() + rgb_sensor_spec.uuid = "color" + rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR + rgb_sensor_spec.resolution = self.resolution + rgb_sensor_spec.hfov = self.hfov + rgb_sensor_spec.position = [0.0, 0.0, 0] + agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec, depth_sensor_spec]) + + cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg]) + self.sim = habitat_sim.Simulator(cfg) + if self.navmesh is not None and self.navmesh != "": + # Use pre-computed navmesh when available (usually better than those generated automatically) + self.sim.pathfinder.load_nav_mesh(self.navmesh) + + if not self.sim.pathfinder.is_loaded: + # Try to compute a navmesh + navmesh_settings = habitat_sim.NavMeshSettings() + navmesh_settings.set_defaults() + self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True) + + # Ensure that the navmesh is not empty + if not self.sim.pathfinder.is_loaded: + raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})") + + self.agent = self.sim.initialize_agent(agent_id=0) + + def close(self): + self.sim.close() + + def __del__(self): + self.sim.close() + + def __len__(self): + return self.size + + def sample_random_viewpoint(self): + """ Sample a random viewpoint using the navmesh """ + nav_point = self.sim.pathfinder.get_random_navigable_point() + + # Sample a random viewpoint height + viewpoint_height = np.random.uniform(*self.height_range) + viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP + viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range) + return viewpoint_position, viewpoint_orientation, nav_point + + def sample_other_random_viewpoint(self, observed_point, nav_point): + """ Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point.""" + other_nav_point = nav_point + + walk_directions = self.random_step_variance * np.asarray([1,0,1]) + for i in range(self.random_steps_count): + temp = self.sim.pathfinder.snap_point(other_nav_point + walk_directions * np.random.normal(size=3)) + # Snapping may return nan when it fails + if not np.isnan(temp[0]): + other_nav_point = temp + + other_viewpoint_height = np.random.uniform(*self.height_range) + other_viewpoint_position = other_nav_point + other_viewpoint_height * habitat_sim.geo.UP + + # Set viewing direction towards the central point + rotation, position = look_at_for_habitat(eye=other_viewpoint_position, center=observed_point, up=habitat_sim.geo.UP, return_cam2world=True) + rotation = rotation * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range) + return position, rotation, other_nav_point + + def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud): + """ Check if a viewpoint is valid and overlaps significantly with a reference one. """ + # Observation + pixels_count = self.resolution[0] * self.resolution[1] + valid_fraction = len(other_pointcloud) / pixels_count + assert valid_fraction <= 1.0 and valid_fraction >= 0.0 + overlap = compute_pointcloud_overlaps_scikit(ref_pointcloud, other_pointcloud, self.distance_threshold, compute_symmetric=True) + covisibility = min(overlap["intersection1"] / pixels_count, overlap["intersection2"] / pixels_count) + is_valid = (valid_fraction >= self.minimum_valid_fraction) and (covisibility >= self.minimum_covisibility) + return is_valid, valid_fraction, covisibility + + def is_other_viewpoint_overlapping(self, ref_pointcloud, observation, position, rotation): + """ Check if a viewpoint is valid and overlaps significantly with a reference one. """ + # Observation + other_pointcloud = compute_pointcloud(observation['depth'], self.hfov, position, rotation) + return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud) + + def render_viewpoint(self, viewpoint_position, viewpoint_orientation): + agent_state = habitat_sim.AgentState() + agent_state.position = viewpoint_position + agent_state.rotation = viewpoint_orientation + self.agent.set_state(agent_state) + viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0) + _append_camera_parameters(viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation) + return viewpoint_observations + + def __getitem__(self, useless_idx): + ref_position, ref_orientation, nav_point = self.sample_random_viewpoint() + ref_observations = self.render_viewpoint(ref_position, ref_orientation) + # Extract point cloud + ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov, + camera_position=ref_position, camera_rotation=ref_orientation) + + pixels_count = self.resolution[0] * self.resolution[1] + ref_valid_fraction = len(ref_pointcloud) / pixels_count + assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0 + if ref_valid_fraction < self.minimum_valid_fraction: + # This should produce a recursion error at some point when something is very wrong. + return self[0] + # Pick an reference observed point in the point cloud + observed_point = np.mean(ref_pointcloud, axis=0) + + # Add the first image as reference + viewpoints_observations = [ref_observations] + viewpoints_covisibility = [ref_valid_fraction] + viewpoints_positions = [ref_position] + viewpoints_orientations = [quaternion.as_float_array(ref_orientation)] + viewpoints_clouds = [ref_pointcloud] + viewpoints_valid_fractions = [ref_valid_fraction] + + for _ in range(self.views_count - 1): + # Generate an other viewpoint using some dummy random walk + successful_sampling = False + for sampling_attempt in range(self.max_attempts_count): + position, rotation, _ = self.sample_other_random_viewpoint(observed_point, nav_point) + # Observation + other_viewpoint_observations = self.render_viewpoint(position, rotation) + other_pointcloud = compute_pointcloud(other_viewpoint_observations['depth'], self.hfov, position, rotation) + + is_valid, valid_fraction, covisibility = self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud) + if is_valid: + successful_sampling = True + break + if not successful_sampling: + print("WARNING: Maximum number of attempts reached.") + # Dirty hack, try using a novel original viewpoint + return self[0] + viewpoints_observations.append(other_viewpoint_observations) + viewpoints_covisibility.append(covisibility) + viewpoints_positions.append(position) + viewpoints_orientations.append(quaternion.as_float_array(rotation)) # WXYZ convention for the quaternion encoding. + viewpoints_clouds.append(other_pointcloud) + viewpoints_valid_fractions.append(valid_fraction) + + # Estimate relations between all pairs of images + pairwise_visibility_ratios = np.ones((len(viewpoints_observations), len(viewpoints_observations))) + for i in range(len(viewpoints_observations)): + pairwise_visibility_ratios[i,i] = viewpoints_valid_fractions[i] + for j in range(i+1, len(viewpoints_observations)): + overlap = compute_pointcloud_overlaps_scikit(viewpoints_clouds[i], viewpoints_clouds[j], self.distance_threshold, compute_symmetric=True) + pairwise_visibility_ratios[i,j] = overlap['intersection1'] / pixels_count + pairwise_visibility_ratios[j,i] = overlap['intersection2'] / pixels_count + + # IoU is relative to the image 0 + data = {"observations": viewpoints_observations, + "positions": np.asarray(viewpoints_positions), + "orientations": np.asarray(viewpoints_orientations), + "covisibility_ratios": np.asarray(viewpoints_covisibility), + "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float), + "pairwise_visibility_ratios": np.asarray(pairwise_visibility_ratios, dtype=float), + } + + if self.transform is not None: + data = self.transform(data) + return data + + def generate_random_spiral_trajectory(self, images_count = 100, max_radius=0.5, half_turns=5, use_constant_orientation=False): + """ + Return a list of images corresponding to a spiral trajectory from a random starting point. + Useful to generate nice visualisations. + Use an even number of half turns to get a nice "C1-continuous" loop effect + """ + ref_position, ref_orientation, navpoint = self.sample_random_viewpoint() + ref_observations = self.render_viewpoint(ref_position, ref_orientation) + ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov, + camera_position=ref_position, camera_rotation=ref_orientation) + pixels_count = self.resolution[0] * self.resolution[1] + if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction: + # Dirty hack: ensure that the valid part of the image is significant + return self.generate_random_spiral_trajectory(images_count, max_radius, half_turns, use_constant_orientation) + + # Pick an observed point in the point cloud + observed_point = np.mean(ref_pointcloud, axis=0) + ref_R, ref_t = compute_camera_pose_opencv_convention(ref_position, ref_orientation) + + images = [] + is_valid = [] + # Spiral trajectory, use_constant orientation + for i, alpha in enumerate(np.linspace(0, 1, images_count)): + r = max_radius * np.abs(np.sin(alpha * np.pi)) # Increase then decrease the radius + theta = alpha * half_turns * np.pi + x = r * np.cos(theta) + y = r * np.sin(theta) + z = 0.0 + position = ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3,1)).flatten() + if use_constant_orientation: + orientation = ref_orientation + else: + # trajectory looking at a mean point in front of the ref observation + orientation, position = look_at_for_habitat(eye=position, center=observed_point, up=habitat_sim.geo.UP) + observations = self.render_viewpoint(position, orientation) + images.append(observations['color'][...,:3]) + _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(ref_pointcloud, observations, position, orientation) + is_valid.append(_is_valid) + return images, np.all(is_valid) \ No newline at end of file diff --git a/monst3r/croco/datasets/habitat_sim/pack_metadata_files.py b/monst3r/croco/datasets/habitat_sim/pack_metadata_files.py new file mode 100644 index 0000000..10672a0 --- /dev/null +++ b/monst3r/croco/datasets/habitat_sim/pack_metadata_files.py @@ -0,0 +1,69 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +""" +Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere. +""" +import os +import glob +from tqdm import tqdm +import shutil +import json +from datasets.habitat_sim.paths import * +import argparse +import collections + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input_dir") + parser.add_argument("output_dir") + args = parser.parse_args() + + input_dirname = args.input_dir + output_dirname = args.output_dir + + input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True) + + images_count = collections.defaultdict(lambda : 0) + + os.makedirs(output_dirname) + for input_filename in tqdm(input_metadata_filenames): + # Ignore empty files + with open(input_filename, "r") as f: + original_metadata = json.load(f) + if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0: + print("No views in", input_filename) + continue + + relpath = os.path.relpath(input_filename, input_dirname) + print(relpath) + + # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability. + # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting by the same string pattern. + scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True)) + metadata = dict() + for key, value in original_metadata.items(): + if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": + known_path = False + for dataset, dataset_path in scenes_dataset_paths.items(): + if value.startswith(dataset_path): + value = os.path.join(dataset, os.path.relpath(value, dataset_path)) + known_path = True + break + if not known_path: + raise KeyError("Unknown path:" + value) + metadata[key] = value + + # Compile some general statistics while packing data + scene_split = metadata["scene"].split("/") + upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0] + images_count[upper_level] += len(metadata["multiviews"]) + + output_filename = os.path.join(output_dirname, relpath) + os.makedirs(os.path.dirname(output_filename), exist_ok=True) + with open(output_filename, "w") as f: + json.dump(metadata, f) + + # Print statistics + print("Images count:") + for upper_level, count in images_count.items(): + print(f"- {upper_level}: {count}") \ No newline at end of file diff --git a/monst3r/croco/datasets/habitat_sim/paths.py b/monst3r/croco/datasets/habitat_sim/paths.py new file mode 100644 index 0000000..4d63b5f --- /dev/null +++ b/monst3r/croco/datasets/habitat_sim/paths.py @@ -0,0 +1,129 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +""" +Paths to Habitat-Sim scenes +""" + +import os +import json +import collections +from tqdm import tqdm + + +# Hardcoded path to the different scene datasets +SCENES_DATASET = { + "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/", + "gibson": "./data/habitat-sim-data/scene_datasets/gibson/", + "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/", + "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/", + "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/", + "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/", + "scannet": "./data/habitat-sim/scene_datasets/scannet/" +} + +SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"]) + +def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]): + scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json") + scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"] + navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"] + scenes_data = [] + for idx in range(len(scenes)): + output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx]) + # Add scene + data = SceneData(scene_dataset_config_file=scene_dataset_config_file, + scene = scenes[idx] + ".scene_instance.json", + navmesh = os.path.join(base_path, navmeshes[idx]), + output_dir = output_dir) + scenes_data.append(data) + return scenes_data + +def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]): + scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json") + scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], []) + navmeshes = ""#[f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"] + scenes_data = [] + for idx in range(len(scenes)): + output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx]) + data = SceneData(scene_dataset_config_file=scene_dataset_config_file, + scene = scenes[idx], + navmesh = "", + output_dir = output_dir) + scenes_data.append(data) + return scenes_data + +def list_replica_scenes(base_output_dir, base_path): + scenes_data = [] + for scene_id in os.listdir(base_path): + scene = os.path.join(base_path, scene_id, "mesh.ply") + navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it + scene_dataset_config_file = "" + output_dir = os.path.join(base_output_dir, scene_id) + # Add scene only if it does not exist already, or if exist_ok + data = SceneData(scene_dataset_config_file = scene_dataset_config_file, + scene = scene, + navmesh = navmesh, + output_dir = output_dir) + scenes_data.append(data) + return scenes_data + + +def list_scenes(base_output_dir, base_path): + """ + Generic method iterating through a base_path folder to find scenes. + """ + scenes_data = [] + for root, dirs, files in os.walk(base_path, followlinks=True): + folder_scenes_data = [] + for file in files: + name, ext = os.path.splitext(file) + if ext == ".glb": + scene = os.path.join(root, name + ".glb") + navmesh = os.path.join(root, name + ".navmesh") + if not os.path.exists(navmesh): + navmesh = "" + relpath = os.path.relpath(root, base_path) + output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name)) + data = SceneData(scene_dataset_config_file="", + scene = scene, + navmesh = navmesh, + output_dir = output_dir) + folder_scenes_data.append(data) + + # Specific check for HM3D: + # When two meshesxxxx.basis.glb and xxxx.glb are present, use the 'basis' version. + basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")] + if len(basis_scenes) != 0: + folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)] + + scenes_data.extend(folder_scenes_data) + return scenes_data + +def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET): + scenes_data = [] + + # HM3D + for split in ("minival", "train", "val", "examples"): + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"), + base_path=f"{scenes_dataset_paths['hm3d']}/{split}") + + # Gibson + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"), + base_path=scenes_dataset_paths["gibson"]) + + # Habitat test scenes (just a few) + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"), + base_path=scenes_dataset_paths["habitat-test-scenes"]) + + # ReplicaCAD (baked lightning) + scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir) + + # ScanNet + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"), + base_path=scenes_dataset_paths["scannet"]) + + # Replica + list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"), + base_path=scenes_dataset_paths["replica"]) + return scenes_data diff --git a/monst3r/croco/datasets/pairs_dataset.py b/monst3r/croco/datasets/pairs_dataset.py new file mode 100644 index 0000000..9f10752 --- /dev/null +++ b/monst3r/croco/datasets/pairs_dataset.py @@ -0,0 +1,109 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import os +from torch.utils.data import Dataset +from PIL import Image + +from datasets.transforms import get_pair_transforms + +def load_image(impath): + return Image.open(impath) + +def load_pairs_from_cache_file(fname, root=''): + assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname) + with open(fname, 'r') as fid: + lines = fid.read().strip().splitlines() + pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines] + return pairs + +def load_pairs_from_list_file(fname, root=''): + assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname) + with open(fname, 'r') as fid: + lines = fid.read().strip().splitlines() + pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')] + return pairs + + +def write_cache_file(fname, pairs, root=''): + if len(root)>0: + if not root.endswith('/'): root+='/' + assert os.path.isdir(root) + s = '' + for im1, im2 in pairs: + if len(root)>0: + assert im1.startswith(root), im1 + assert im2.startswith(root), im2 + s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):]) + with open(fname, 'w') as fid: + fid.write(s[:-1]) + +def parse_and_cache_all_pairs(dname, data_dir='./data/'): + if dname=='habitat_release': + dirname = os.path.join(data_dir, 'habitat_release') + assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname + cache_file = os.path.join(dirname, 'pairs.txt') + assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file + + print('Parsing pairs for dataset: '+dname) + pairs = [] + for root, dirs, files in os.walk(dirname): + if 'val' in root: continue + dirs.sort() + pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')] + print('Found {:,} pairs'.format(len(pairs))) + print('Writing cache to: '+cache_file) + write_cache_file(cache_file, pairs, root=dirname) + + else: + raise NotImplementedError('Unknown dataset: '+dname) + +def dnames_to_image_pairs(dnames, data_dir='./data/'): + """ + dnames: list of datasets with image pairs, separated by + + """ + all_pairs = [] + for dname in dnames.split('+'): + if dname=='habitat_release': + dirname = os.path.join(data_dir, 'habitat_release') + assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname + cache_file = os.path.join(dirname, 'pairs.txt') + assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file + pairs = load_pairs_from_cache_file(cache_file, root=dirname) + elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']: + dirname = os.path.join(data_dir, dname+'_crops') + assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname) + list_file = os.path.join(dirname, 'listing.txt') + assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. {:s}".format(dname, list_file) + pairs = load_pairs_from_list_file(list_file, root=dirname) + print(' {:s}: {:,} pairs'.format(dname, len(pairs))) + all_pairs += pairs + if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs))) + return all_pairs + + +class PairsDataset(Dataset): + + def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'): + super().__init__() + self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir) + self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize) + + def __len__(self): + return len(self.image_pairs) + + def __getitem__(self, index): + im1path, im2path = self.image_pairs[index] + im1 = load_image(im1path) + im2 = load_image(im2path) + if self.transforms is not None: im1, im2 = self.transforms(im1, im2) + return im1, im2 + + +if __name__=="__main__": + import argparse + parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset") + parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored") + parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset") + args = parser.parse_args() + parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir) diff --git a/monst3r/croco/datasets/transforms.py b/monst3r/croco/datasets/transforms.py new file mode 100644 index 0000000..216bac6 --- /dev/null +++ b/monst3r/croco/datasets/transforms.py @@ -0,0 +1,95 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch +import torchvision.transforms +import torchvision.transforms.functional as F + +# "Pair": apply a transform on a pair +# "Both": apply the exact same transform to both images + +class ComposePair(torchvision.transforms.Compose): + def __call__(self, img1, img2): + for t in self.transforms: + img1, img2 = t(img1, img2) + return img1, img2 + +class NormalizeBoth(torchvision.transforms.Normalize): + def forward(self, img1, img2): + img1 = super().forward(img1) + img2 = super().forward(img2) + return img1, img2 + +class ToTensorBoth(torchvision.transforms.ToTensor): + def __call__(self, img1, img2): + img1 = super().__call__(img1) + img2 = super().__call__(img2) + return img1, img2 + +class RandomCropPair(torchvision.transforms.RandomCrop): + # the crop will be intentionally different for the two images with this class + def forward(self, img1, img2): + img1 = super().forward(img1) + img2 = super().forward(img2) + return img1, img2 + +class ColorJitterPair(torchvision.transforms.ColorJitter): + # can be symmetric (same for both images) or assymetric (different jitter params for each image) depending on assymetric_prob + def __init__(self, assymetric_prob, **kwargs): + super().__init__(**kwargs) + self.assymetric_prob = assymetric_prob + def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor): + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return img + + def forward(self, img1, img2): + + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor) + if torch.rand(1) < self.assymetric_prob: # assymetric: + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor) + return img1, img2 + +def get_pair_transforms(transform_str, totensor=True, normalize=True): + # transform_str is eg crop224+color + trfs = [] + for s in transform_str.split('+'): + if s.startswith('crop'): + size = int(s[len('crop'):]) + trfs.append(RandomCropPair(size)) + elif s=='acolor': + trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0)) + elif s=='': # if transform_str was "" + pass + else: + raise NotImplementedError('Unknown augmentation: '+s) + + if totensor: + trfs.append( ToTensorBoth() ) + if normalize: + trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ) + + if len(trfs)==0: + return None + elif len(trfs)==1: + return trfs + else: + return ComposePair(trfs) + + + + + diff --git a/monst3r/croco/demo.py b/monst3r/croco/demo.py new file mode 100644 index 0000000..91b80cc --- /dev/null +++ b/monst3r/croco/demo.py @@ -0,0 +1,55 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch +from models.croco import CroCoNet +from PIL import Image +import torchvision.transforms +from torchvision.transforms import ToTensor, Normalize, Compose + +def main(): + device = torch.device('cuda:0' if torch.cuda.is_available() and torch.cuda.device_count()>0 else 'cpu') + + # load 224x224 images and transform them to tensor + imagenet_mean = [0.485, 0.456, 0.406] + imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1,3,1,1).to(device, non_blocking=True) + imagenet_std = [0.229, 0.224, 0.225] + imagenet_std_tensor = torch.tensor(imagenet_std).view(1,3,1,1).to(device, non_blocking=True) + trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)]) + image1 = trfs(Image.open('assets/Chateau1.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0) + image2 = trfs(Image.open('assets/Chateau2.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0) + + # load model + ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu') + model = CroCoNet( **ckpt.get('croco_kwargs',{})).to(device) + model.eval() + msg = model.load_state_dict(ckpt['model'], strict=True) + + # forward + with torch.inference_mode(): + out, mask, target = model(image1, image2) + + # the output is normalized, thus use the mean/std of the actual image to go back to RGB space + patchified = model.patchify(image1) + mean = patchified.mean(dim=-1, keepdim=True) + var = patchified.var(dim=-1, keepdim=True) + decoded_image = model.unpatchify(out * (var + 1.e-6)**.5 + mean) + # undo imagenet normalization, prepare masked image + decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor + input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor + ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor + image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None]) + masked_input_image = ((1 - image_masks) * input_image) + + # make visualization + visualization = torch.cat((ref_image, masked_input_image, decoded_image, input_image), dim=3) # 4*(B, 3, H, W) -> B, 3, H, W*4 + B, C, H, W = visualization.shape + visualization = visualization.permute(1, 0, 2, 3).reshape(C, B*H, W) + visualization = torchvision.transforms.functional.to_pil_image(torch.clamp(visualization, 0, 1)) + fname = "demo_output.png" + visualization.save(fname) + print('Visualization save in '+fname) + + +if __name__=="__main__": + main() diff --git a/monst3r/croco/interactive_demo.ipynb b/monst3r/croco/interactive_demo.ipynb new file mode 100644 index 0000000..6cfc960 --- /dev/null +++ b/monst3r/croco/interactive_demo.ipynb @@ -0,0 +1,271 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Interactive demo of Cross-view Completion." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n", + "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "from models.croco import CroCoNet\n", + "from ipywidgets import interact, interactive, fixed, interact_manual\n", + "import ipywidgets as widgets\n", + "import matplotlib.pyplot as plt\n", + "import quaternion\n", + "import models.masking" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load CroCo model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')\n", + "model = CroCoNet( **ckpt.get('croco_kwargs',{}))\n", + "msg = model.load_state_dict(ckpt['model'], strict=True)\n", + "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n", + "device = torch.device('cuda:0' if use_gpu else 'cpu')\n", + "model = model.eval()\n", + "model = model.to(device=device)\n", + "print(msg)\n", + "\n", + "def process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches=False):\n", + " \"\"\"\n", + " Perform Cross-View completion using two input images, specified using Numpy arrays.\n", + " \"\"\"\n", + " # Replace the mask generator\n", + " model.mask_generator = models.masking.RandomMask(model.patch_embed.num_patches, masking_ratio)\n", + "\n", + " # ImageNet-1k color normalization\n", + " imagenet_mean = torch.as_tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1).to(device)\n", + " imagenet_std = torch.as_tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1).to(device)\n", + "\n", + " normalize_input_colors = True\n", + " is_output_normalized = True\n", + " with torch.no_grad():\n", + " # Cast data to torch\n", + " target_image = (torch.as_tensor(target_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n", + " ref_image = (torch.as_tensor(ref_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n", + "\n", + " if normalize_input_colors:\n", + " ref_image = (ref_image - imagenet_mean) / imagenet_std\n", + " target_image = (target_image - imagenet_mean) / imagenet_std\n", + "\n", + " out, mask, _ = model(target_image, ref_image)\n", + " # # get target\n", + " if not is_output_normalized:\n", + " predicted_image = model.unpatchify(out)\n", + " else:\n", + " # The output only contains higher order information,\n", + " # we retrieve mean and standard deviation from the actual target image\n", + " patchified = model.patchify(target_image)\n", + " mean = patchified.mean(dim=-1, keepdim=True)\n", + " var = patchified.var(dim=-1, keepdim=True)\n", + " pred_renorm = out * (var + 1.e-6)**.5 + mean\n", + " predicted_image = model.unpatchify(pred_renorm)\n", + "\n", + " image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])\n", + " masked_target_image = (1 - image_masks) * target_image\n", + " \n", + " if not reconstruct_unmasked_patches:\n", + " # Replace unmasked patches by their actual values\n", + " predicted_image = predicted_image * image_masks + masked_target_image\n", + "\n", + " # Unapply color normalization\n", + " if normalize_input_colors:\n", + " predicted_image = predicted_image * imagenet_std + imagenet_mean\n", + " masked_target_image = masked_target_image * imagenet_std + imagenet_mean\n", + " \n", + " # Cast to Numpy\n", + " masked_target_image = np.asarray(torch.clamp(masked_target_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n", + " predicted_image = np.asarray(torch.clamp(predicted_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n", + " return masked_target_image, predicted_image" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Use the Habitat simulator to render images from arbitrary viewpoints (requires habitat_sim to be installed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"MAGNUM_LOG\"]=\"quiet\"\n", + "os.environ[\"HABITAT_SIM_LOG\"]=\"quiet\"\n", + "import habitat_sim\n", + "\n", + "scene = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.glb\"\n", + "navmesh = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh\"\n", + "\n", + "sim_cfg = habitat_sim.SimulatorConfiguration()\n", + "if use_gpu: sim_cfg.gpu_device_id = 0\n", + "sim_cfg.scene_id = scene\n", + "sim_cfg.load_semantic_mesh = False\n", + "rgb_sensor_spec = habitat_sim.CameraSensorSpec()\n", + "rgb_sensor_spec.uuid = \"color\"\n", + "rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR\n", + "rgb_sensor_spec.resolution = (224,224)\n", + "rgb_sensor_spec.hfov = 56.56\n", + "rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n", + "rgb_sensor_spec.orientation = [0, 0, 0]\n", + "agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec])\n", + "\n", + "\n", + "cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])\n", + "sim = habitat_sim.Simulator(cfg)\n", + "if navmesh is not None:\n", + " sim.pathfinder.load_nav_mesh(navmesh)\n", + "agent = sim.initialize_agent(agent_id=0)\n", + "\n", + "def sample_random_viewpoint():\n", + " \"\"\" Sample a random viewpoint using the navmesh \"\"\"\n", + " nav_point = sim.pathfinder.get_random_navigable_point()\n", + " # Sample a random viewpoint height\n", + " viewpoint_height = np.random.uniform(1.0, 1.6)\n", + " viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP\n", + " viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(-np.pi, np.pi) * habitat_sim.geo.UP)\n", + " return viewpoint_position, viewpoint_orientation\n", + "\n", + "def render_viewpoint(position, orientation):\n", + " agent_state = habitat_sim.AgentState()\n", + " agent_state.position = position\n", + " agent_state.rotation = orientation\n", + " agent.set_state(agent_state)\n", + " viewpoint_observations = sim.get_sensor_observations(agent_ids=0)\n", + " image = viewpoint_observations['color'][:,:,:3]\n", + " image = np.asarray(np.clip(1.5 * np.asarray(image, dtype=float), 0, 255), dtype=np.uint8)\n", + " return image" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sample a random reference view" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ref_position, ref_orientation = sample_random_viewpoint()\n", + "ref_image = render_viewpoint(ref_position, ref_orientation)\n", + "plt.clf()\n", + "fig, axes = plt.subplots(1,1, squeeze=False, num=1)\n", + "axes[0,0].imshow(ref_image)\n", + "for ax in axes.flatten():\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Interactive cross-view completion using CroCo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstruct_unmasked_patches = False\n", + "\n", + "def show_demo(masking_ratio, x, y, z, panorama, elevation):\n", + " R = quaternion.as_rotation_matrix(ref_orientation)\n", + " target_position = ref_position + x * R[:,0] + y * R[:,1] + z * R[:,2]\n", + " target_orientation = (ref_orientation\n", + " * quaternion.from_rotation_vector(-elevation * np.pi/180 * habitat_sim.geo.LEFT) \n", + " * quaternion.from_rotation_vector(-panorama * np.pi/180 * habitat_sim.geo.UP))\n", + " \n", + " ref_image = render_viewpoint(ref_position, ref_orientation)\n", + " target_image = render_viewpoint(target_position, target_orientation)\n", + "\n", + " masked_target_image, predicted_image = process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches)\n", + "\n", + " fig, axes = plt.subplots(1,4, squeeze=True, dpi=300)\n", + " axes[0].imshow(ref_image)\n", + " axes[0].set_xlabel(\"Reference\")\n", + " axes[1].imshow(masked_target_image)\n", + " axes[1].set_xlabel(\"Masked target\")\n", + " axes[2].imshow(predicted_image)\n", + " axes[2].set_xlabel(\"Reconstruction\") \n", + " axes[3].imshow(target_image)\n", + " axes[3].set_xlabel(\"Target\")\n", + " for ax in axes.flatten():\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + "\n", + "interact(show_demo,\n", + " masking_ratio=widgets.FloatSlider(description='masking', value=0.9, min=0.0, max=1.0),\n", + " x=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", + " y=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", + " z=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", + " panorama=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),\n", + " elevation=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5));" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "vscode": { + "interpreter": { + "hash": "f9237820cd248d7e07cb4fb9f0e4508a85d642f19d831560c0a4b61f3e907e67" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/monst3r/croco/models/blocks.py b/monst3r/croco/models/blocks.py new file mode 100644 index 0000000..9fc2d82 --- /dev/null +++ b/monst3r/croco/models/blocks.py @@ -0,0 +1,252 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + + +# -------------------------------------------------------- +# Main encoder/decoder blocks +# -------------------------------------------------------- +# References: +# timm +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py + + +import torch +import torch.nn as nn + +from itertools import repeat +import collections.abc + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + return parse +to_2tuple = _ntuple(2) + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks""" + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + +class Attention(nn.Module): + + def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x, xpos): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3) + q, k, v = [qkv[:,:,i] for i in range(3)] + # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, xpos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class CrossAttention(nn.Module): + + def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.projq = nn.Linear(dim, dim, bias=qkv_bias) + self.projk = nn.Linear(dim, dim, bias=qkv_bias) + self.projv = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = rope + + def forward(self, query, key, value, qpos, kpos): + B, Nq, C = query.shape + Nk = key.shape[1] + Nv = value.shape[1] + + q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) + k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) + v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Nq, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class DecoderBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.norm3 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.norm_y = norm_layer(dim) if norm_mem else nn.Identity() + + def forward(self, x, y, xpos, ypos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + y_ = self.norm_y(y) + x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos)) + x = x + self.drop_path(self.mlp(self.norm3(x))) + return x, y + + +# patch embedding +class PositionGetter(object): + """ return positions of patches """ + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, h, w, device): + if not (h,w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2) + pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone() + return pos + +class PatchEmbed(nn.Module): + """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, init='xavier'): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + self.position_getter = PositionGetter() + self.init_type = init + + def forward(self, x): + B, C, H, W = x.shape + torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, pos + + def _init_weights(self): + w = self.proj.weight.data + if self.init_type == 'xavier': + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + elif self.init_type == 'kaiming': + torch.nn.init.kaiming_uniform_(w.view([w.shape[0], -1])) + elif self.init_type == 'zero': + torch.nn.init.zeros_(w) + bias = getattr(self.proj, 'bias', None) + if bias is not None: + torch.nn.init.zeros_(bias) + else: + raise ValueError(f"Unknown init type {self.init_type}") + diff --git a/monst3r/croco/models/criterion.py b/monst3r/croco/models/criterion.py new file mode 100644 index 0000000..11696c4 --- /dev/null +++ b/monst3r/croco/models/criterion.py @@ -0,0 +1,37 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Criterion to train CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# -------------------------------------------------------- + +import torch + +class MaskedMSE(torch.nn.Module): + + def __init__(self, norm_pix_loss=False, masked=True): + """ + norm_pix_loss: normalize each patch by their pixel mean and variance + masked: compute loss over the masked patches only + """ + super().__init__() + self.norm_pix_loss = norm_pix_loss + self.masked = masked + + def forward(self, pred, mask, target): + + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + if self.masked: + loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches + else: + loss = loss.mean() # mean loss + return loss diff --git a/monst3r/croco/models/croco.py b/monst3r/croco/models/croco.py new file mode 100644 index 0000000..14c6863 --- /dev/null +++ b/monst3r/croco/models/croco.py @@ -0,0 +1,249 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + + +# -------------------------------------------------------- +# CroCo model during pretraining +# -------------------------------------------------------- + + + +import torch +import torch.nn as nn +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 +from functools import partial + +from models.blocks import Block, DecoderBlock, PatchEmbed +from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D +from models.masking import RandomMask + + +class CroCoNet(nn.Module): + + def __init__(self, + img_size=224, # input image size + patch_size=16, # patch_size + mask_ratio=0.9, # ratios of masked tokens + enc_embed_dim=768, # encoder feature dimension + enc_depth=12, # encoder depth + enc_num_heads=12, # encoder number of heads in the transformer block + dec_embed_dim=512, # decoder feature dimension + dec_depth=8, # decoder depth + dec_num_heads=16, # decoder number of heads in the transformer block + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder + pos_embed='cosine', # positional embedding (either cosine or RoPE100) + ): + + super(CroCoNet, self).__init__() + + # patch embeddings (with initialization done as in MAE) + self._set_patch_embed(img_size, patch_size, enc_embed_dim) + + # mask generations + self._set_mask_generator(self.patch_embed.num_patches, mask_ratio) + + self.pos_embed = pos_embed + if pos_embed=='cosine': + # positional embedding of the encoder + enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0) + self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float()) + # positional embedding of the decoder + dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0) + self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float()) + # pos embedding in each block + self.rope = None # nothing for cosine + elif pos_embed.startswith('RoPE'): # eg RoPE100 + self.enc_pos_embed = None # nothing to add in the encoder with RoPE + self.dec_pos_embed = None # nothing to add in the decoder with RoPE + if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions") + freq = float(pos_embed[len('RoPE'):]) + self.rope = RoPE2D(freq=freq) + else: + raise NotImplementedError('Unknown pos_embed '+pos_embed) + + # transformer for the encoder + self.enc_depth = enc_depth + self.enc_embed_dim = enc_embed_dim + self.enc_blocks = nn.ModuleList([ + Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope) + for i in range(enc_depth)]) + self.enc_norm = norm_layer(enc_embed_dim) + + # masked tokens + self._set_mask_token(dec_embed_dim) + + # decoder + self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec) + + # prediction head + self._set_prediction_head(dec_embed_dim, patch_size) + + # initializer weights + self.initialize_weights() + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim) + + def _set_mask_generator(self, num_patches, mask_ratio): + self.mask_generator = RandomMask(num_patches, mask_ratio) + + def _set_mask_token(self, dec_embed_dim): + self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim)) + + def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec): + self.dec_depth = dec_depth + self.dec_embed_dim = dec_embed_dim + # transfer from encoder to decoder + self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True) + # transformer for the decoder + self.dec_blocks = nn.ModuleList([ + DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope) + for i in range(dec_depth)]) + # final norm layer + self.dec_norm = norm_layer(dec_embed_dim) + + def _set_prediction_head(self, dec_embed_dim, patch_size): + self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True) + + + def initialize_weights(self): + # patch embed + self.patch_embed._init_weights() + # mask tokens + if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02) + # linears and layer norms + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _encode_image(self, image, do_mask=False, return_all_blocks=False): + """ + image has B x 3 x img_size x img_size + do_mask: whether to perform masking or not + return_all_blocks: if True, return the features at the end of every block + instead of just the features from the last block (eg for some prediction heads) + """ + # embed the image into patches (x has size B x Npatches x C) + # and get position if each return patch (pos has size B x Npatches x 2) + x, pos = self.patch_embed(image) + # add positional embedding without cls token + if self.enc_pos_embed is not None: + x = x + self.enc_pos_embed[None,...] + # apply masking + B,N,C = x.size() + if do_mask: + masks = self.mask_generator(x) + x = x[~masks].view(B, -1, C) + posvis = pos[~masks].view(B, -1, 2) + else: + B,N,C = x.size() + masks = torch.zeros((B,N), dtype=bool) + posvis = pos + # now apply the transformer encoder and normalization + if return_all_blocks: + out = [] + for blk in self.enc_blocks: + x = blk(x, posvis) + out.append(x) + out[-1] = self.enc_norm(out[-1]) + return out, pos, masks + else: + for blk in self.enc_blocks: + x = blk(x, posvis) + x = self.enc_norm(x) + return x, pos, masks + + def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False): + """ + return_all_blocks: if True, return the features at the end of every block + instead of just the features from the last block (eg for some prediction heads) + + masks1 can be None => assume image1 fully visible + """ + # encoder to decoder layer + visf1 = self.decoder_embed(feat1) + f2 = self.decoder_embed(feat2) + # append masked tokens to the sequence + B,Nenc,C = visf1.size() + if masks1 is None: # downstreams + f1_ = visf1 + else: # pretraining + Ntotal = masks1.size(1) + f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype) + f1_[~masks1] = visf1.view(B * Nenc, C) + # add positional embedding + if self.dec_pos_embed is not None: + f1_ = f1_ + self.dec_pos_embed + f2 = f2 + self.dec_pos_embed + # apply Transformer blocks + out = f1_ + out2 = f2 + if return_all_blocks: + _out, out = out, [] + for blk in self.dec_blocks: + _out, out2 = blk(_out, out2, pos1, pos2) + out.append(_out) + out[-1] = self.dec_norm(out[-1]) + else: + for blk in self.dec_blocks: + out, out2 = blk(out, out2, pos1, pos2) + out = self.dec_norm(out) + return out + + def patchify(self, imgs): + """ + imgs: (B, 3, H, W) + x: (B, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + + return x + + def unpatchify(self, x, channels=3): + """ + x: (N, L, patch_size**2 *channels) + imgs: (N, 3, H, W) + """ + patch_size = self.patch_embed.patch_size[0] + h = w = int(x.shape[1]**.5) + assert h * w == x.shape[1] + x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size)) + return imgs + + def forward(self, img1, img2): + """ + img1: tensor of size B x 3 x img_size x img_size + img2: tensor of size B x 3 x img_size x img_size + + out will be B x N x (3*patch_size*patch_size) + masks are also returned as B x N just in case + """ + # encoder of the masked first image + feat1, pos1, mask1 = self._encode_image(img1, do_mask=True) + # encoder of the second image + feat2, pos2, _ = self._encode_image(img2, do_mask=False) + # decoder + decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2) + # prediction head + out = self.prediction_head(decfeat) + # get target + target = self.patchify(img1) + return out, mask1, target diff --git a/monst3r/croco/models/croco_downstream.py b/monst3r/croco/models/croco_downstream.py new file mode 100644 index 0000000..159dfff --- /dev/null +++ b/monst3r/croco/models/croco_downstream.py @@ -0,0 +1,122 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# CroCo model for downstream tasks +# -------------------------------------------------------- + +import torch + +from .croco import CroCoNet + + +def croco_args_from_ckpt(ckpt): + if 'croco_kwargs' in ckpt: # CroCo v2 released models + return ckpt['croco_kwargs'] + elif 'args' in ckpt and hasattr(ckpt['args'], 'model'): # pretrained using the official code release + s = ckpt['args'].model # eg "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)" + assert s.startswith('CroCoNet(') + return eval('dict'+s[len('CroCoNet'):]) # transform it into the string of a dictionary and evaluate it + else: # CroCo v1 released models + return dict() + +class CroCoDownstreamMonocularEncoder(CroCoNet): + + def __init__(self, + head, + **kwargs): + """ Build network for monocular downstream task, only using the encoder. + It takes an extra argument head, that is called with the features + and a dictionary img_info containing 'width' and 'height' keys + The head is setup with the croconet arguments in this init function + NOTE: It works by *calling super().__init__() but with redefined setters + + """ + super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs) + head.setup(self) + self.head = head + + def _set_mask_generator(self, *args, **kwargs): + """ No mask generator """ + return + + def _set_mask_token(self, *args, **kwargs): + """ No mask token """ + self.mask_token = None + return + + def _set_decoder(self, *args, **kwargs): + """ No decoder """ + return + + def _set_prediction_head(self, *args, **kwargs): + """ No 'prediction head' for downstream tasks.""" + return + + def forward(self, img): + """ + img if of size batch_size x 3 x h x w + """ + B, C, H, W = img.size() + img_info = {'height': H, 'width': W} + need_all_layers = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks + out, _, _ = self._encode_image(img, do_mask=False, return_all_blocks=need_all_layers) + return self.head(out, img_info) + + +class CroCoDownstreamBinocular(CroCoNet): + + def __init__(self, + head, + **kwargs): + """ Build network for binocular downstream task + It takes an extra argument head, that is called with the features + and a dictionary img_info containing 'width' and 'height' keys + The head is setup with the croconet arguments in this init function + """ + super(CroCoDownstreamBinocular, self).__init__(**kwargs) + head.setup(self) + self.head = head + + def _set_mask_generator(self, *args, **kwargs): + """ No mask generator """ + return + + def _set_mask_token(self, *args, **kwargs): + """ No mask token """ + self.mask_token = None + return + + def _set_prediction_head(self, *args, **kwargs): + """ No prediction head for downstream tasks, define your own head """ + return + + def encode_image_pairs(self, img1, img2, return_all_blocks=False): + """ run encoder for a pair of images + it is actually ~5% faster to concatenate the images along the batch dimension + than to encode them separately + """ + ## the two commented lines below is the naive version with separate encoding + #out, pos, _ = self._encode_image(img1, do_mask=False, return_all_blocks=return_all_blocks) + #out2, pos2, _ = self._encode_image(img2, do_mask=False, return_all_blocks=False) + ## and now the faster version + out, pos, _ = self._encode_image( torch.cat( (img1,img2), dim=0), do_mask=False, return_all_blocks=return_all_blocks ) + if return_all_blocks: + out,out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out]))) + out2 = out2[-1] + else: + out,out2 = out.chunk(2, dim=0) + pos,pos2 = pos.chunk(2, dim=0) + return out, out2, pos, pos2 + + def forward(self, img1, img2): + B, C, H, W = img1.size() + img_info = {'height': H, 'width': W} + return_all_blocks = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks + out, out2, pos, pos2 = self.encode_image_pairs(img1, img2, return_all_blocks=return_all_blocks) + if return_all_blocks: + decout = self._decoder(out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks) + decout = out+decout + else: + decout = self._decoder(out, pos, None, out2, pos2, return_all_blocks=return_all_blocks) + return self.head(decout, img_info) \ No newline at end of file diff --git a/monst3r/croco/models/curope/__init__.py b/monst3r/croco/models/curope/__init__.py new file mode 100644 index 0000000..25e3d48 --- /dev/null +++ b/monst3r/croco/models/curope/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +from .curope2d import cuRoPE2D diff --git a/monst3r/croco/models/curope/curope.cpp b/monst3r/croco/models/curope/curope.cpp new file mode 100644 index 0000000..8fe9058 --- /dev/null +++ b/monst3r/croco/models/curope/curope.cpp @@ -0,0 +1,69 @@ +/* + Copyright (C) 2022-present Naver Corporation. All rights reserved. + Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +*/ + +#include + +// forward declaration +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); + +void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) +{ + const int B = tokens.size(0); + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3) / 4; + + auto tok = tokens.accessor(); + auto pos = positions.accessor(); + + for (int b = 0; b < B; b++) { + for (int x = 0; x < 2; x++) { // y and then x (2d) + for (int n = 0; n < N; n++) { + + // grab the token position + const int p = pos[b][n][x]; + + for (int h = 0; h < H; h++) { + for (int d = 0; d < D; d++) { + // grab the two values + float u = tok[b][n][h][d+0+x*2*D]; + float v = tok[b][n][h][d+D+x*2*D]; + + // grab the cos,sin + const float inv_freq = fwd * p / powf(base, d/float(D)); + float c = cosf(inv_freq); + float s = sinf(inv_freq); + + // write the result + tok[b][n][h][d+0+x*2*D] = u*c - v*s; + tok[b][n][h][d+D+x*2*D] = v*c + u*s; + } + } + } + } + } +} + +void rope_2d( torch::Tensor tokens, // B,N,H,D + const torch::Tensor positions, // B,N,2 + const float base, + const float fwd ) +{ + TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); + TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); + TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); + TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); + TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); + TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); + + if (tokens.is_cuda()) + rope_2d_cuda( tokens, positions, base, fwd ); + else + rope_2d_cpu( tokens, positions, base, fwd ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); +} diff --git a/monst3r/croco/models/curope/curope2d.py b/monst3r/croco/models/curope/curope2d.py new file mode 100644 index 0000000..a49c12f --- /dev/null +++ b/monst3r/croco/models/curope/curope2d.py @@ -0,0 +1,40 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch + +try: + import curope as _kernels # run `python setup.py install` +except ModuleNotFoundError: + from . import curope as _kernels # run `python setup.py build_ext --inplace` + + +class cuRoPE2D_func (torch.autograd.Function): + + @staticmethod + def forward(ctx, tokens, positions, base, F0=1): + ctx.save_for_backward(positions) + ctx.saved_base = base + ctx.saved_F0 = F0 + # tokens = tokens.clone() # uncomment this if inplace doesn't work + _kernels.rope_2d( tokens, positions, base, F0 ) + ctx.mark_dirty(tokens) + return tokens + + @staticmethod + def backward(ctx, grad_res): + positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 + _kernels.rope_2d( grad_res, positions, base, -F0 ) + ctx.mark_dirty(grad_res) + return grad_res, None, None, None + + +class cuRoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + + def forward(self, tokens, positions): + cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 ) + return tokens \ No newline at end of file diff --git a/monst3r/croco/models/curope/kernels.cu b/monst3r/croco/models/curope/kernels.cu new file mode 100644 index 0000000..7156cd1 --- /dev/null +++ b/monst3r/croco/models/curope/kernels.cu @@ -0,0 +1,108 @@ +/* + Copyright (C) 2022-present Naver Corporation. All rights reserved. + Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +*/ + +#include +#include +#include +#include + +#define CHECK_CUDA(tensor) {\ + TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ + TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } +void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} + + +template < typename scalar_t > +__global__ void rope_2d_cuda_kernel( + //scalar_t* __restrict__ tokens, + torch::PackedTensorAccessor32 tokens, + const int64_t* __restrict__ pos, + const float base, + const float fwd ) + // const int N, const int H, const int D ) +{ + // tokens shape = (B, N, H, D) + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3); + + // each block update a single token, for all heads + // each thread takes care of a single output + extern __shared__ float shared[]; + float* shared_inv_freq = shared + D; + + const int b = blockIdx.x / N; + const int n = blockIdx.x % N; + + const int Q = D / 4; + // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D] + // u_Y v_Y u_X v_X + + // shared memory: first, compute inv_freq + if (threadIdx.x < Q) + shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); + __syncthreads(); + + // start of X or Y part + const int X = threadIdx.x < D/2 ? 0 : 1; + const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X + + // grab the cos,sin appropriate for me + const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; + const float cos = cosf(freq); + const float sin = sinf(freq); + /* + float* shared_cos_sin = shared + D + D/4; + if ((threadIdx.x % (D/2)) < Q) + shared_cos_sin[m+0] = cosf(freq); + else + shared_cos_sin[m+Q] = sinf(freq); + __syncthreads(); + const float cos = shared_cos_sin[m+0]; + const float sin = shared_cos_sin[m+Q]; + */ + + for (int h = 0; h < H; h++) + { + // then, load all the token for this head in shared memory + shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; + __syncthreads(); + + const float u = shared[m]; + const float v = shared[m+Q]; + + // write output + if ((threadIdx.x % (D/2)) < Q) + tokens[b][n][h][threadIdx.x] = u*cos - v*sin; + else + tokens[b][n][h][threadIdx.x] = v*cos + u*sin; + } +} + +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) +{ + const int B = tokens.size(0); // batch size + const int N = tokens.size(1); // sequence length + const int H = tokens.size(2); // number of heads + const int D = tokens.size(3); // dimension per head + + TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); + TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); + TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); + TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); + + // one block for each layer, one thread per local-max + const int THREADS_PER_BLOCK = D; + const int N_BLOCKS = B * N; // each block takes care of H*D values + const int SHARED_MEM = sizeof(float) * (D + D/4); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { + rope_2d_cuda_kernel <<>> ( + //tokens.data_ptr(), + tokens.packed_accessor32(), + pos.data_ptr(), + base, fwd); //, N, H, D ); + })); +} diff --git a/monst3r/croco/models/curope/setup.py b/monst3r/croco/models/curope/setup.py new file mode 100644 index 0000000..230632e --- /dev/null +++ b/monst3r/croco/models/curope/setup.py @@ -0,0 +1,34 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +from setuptools import setup +from torch import cuda +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +# compile for all possible CUDA architectures +all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() +# alternatively, you can list cuda archs that you want, eg: +# all_cuda_archs = [ + # '-gencode', 'arch=compute_70,code=sm_70', + # '-gencode', 'arch=compute_75,code=sm_75', + # '-gencode', 'arch=compute_80,code=sm_80', + # '-gencode', 'arch=compute_86,code=sm_86' +# ] + +setup( + name = 'curope', + ext_modules = [ + CUDAExtension( + name='curope', + sources=[ + "curope.cpp", + "kernels.cu", + ], + extra_compile_args = dict( + nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, + cxx=['-O3']) + ) + ], + cmdclass = { + 'build_ext': BuildExtension + }) diff --git a/monst3r/croco/models/dpt_block.py b/monst3r/croco/models/dpt_block.py new file mode 100644 index 0000000..d4ddfb7 --- /dev/null +++ b/monst3r/croco/models/dpt_block.py @@ -0,0 +1,450 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# DPT head for ViTs +# -------------------------------------------------------- +# References: +# https://github.com/isl-org/DPT +# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from typing import Union, Tuple, Iterable, List, Optional, Dict + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +def make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + scratch.layer_rn = nn.ModuleList([ + scratch.layer1_rn, + scratch.layer2_rn, + scratch.layer3_rn, + scratch.layer4_rn, + ]) + + return scratch + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + width_ratio=1, + ): + """Init. + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + self.width_ratio = width_ratio + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + if self.width_ratio != 1: + res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear') + + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if self.width_ratio != 1: + # and output.shape[3] < self.width_ratio * output.shape[2] + #size=(image.shape[]) + if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio: + shape = 3 * output.shape[3] + else: + shape = int(self.width_ratio * 2 * output.shape[2]) + output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear') + else: + output = nn.functional.interpolate(output, scale_factor=2, + mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + return output + +def make_fusion_block(features, use_bn, width_ratio=1): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + width_ratio=width_ratio, + ) + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x + +class DPTOutputAdapter(nn.Module): + """DPT output adapter. + + :param num_cahnnels: Number of output channels + :param stride_level: tride level compared to the full-sized image. + E.g. 4 for 1/4th the size of the image. + :param patch_size_full: Int or tuple of the patch size over the full image size. + Patch size for smaller inputs will be computed accordingly. + :param hooks: Index of intermediate layers + :param layer_dims: Dimension of intermediate layers + :param feature_dim: Feature dimension + :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression + :param use_bn: If set to True, activates batch norm + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + + def __init__(self, + num_channels: int = 1, + stride_level: int = 1, + patch_size: Union[int, Tuple[int, int]] = 16, + main_tasks: Iterable[str] = ('rgb',), + hooks: List[int] = [2, 5, 8, 11], + layer_dims: List[int] = [96, 192, 384, 768], + feature_dim: int = 256, + last_dim: int = 32, + use_bn: bool = False, + dim_tokens_enc: Optional[int] = None, + head_type: str = 'regression', + output_width_ratio=1, + **kwargs): + super().__init__() + self.num_channels = num_channels + self.stride_level = stride_level + self.patch_size = pair(patch_size) + self.main_tasks = main_tasks + self.hooks = hooks + self.layer_dims = layer_dims + self.feature_dim = feature_dim + self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None + self.head_type = head_type + + # Actual patch height and width, taking into account stride of input + self.P_H = max(1, self.patch_size[0] // stride_level) + self.P_W = max(1, self.patch_size[1] // stride_level) + + self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False) + + self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + + if self.head_type == 'regression': + # The "DPTDepthModel" head + self.head = nn.Sequential( + nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0) + ) + elif self.head_type == 'semseg': + # The "DPTSegmentationModel" head + self.head = nn.Sequential( + nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(feature_dim, self.num_channels, kernel_size=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + ) + else: + raise ValueError('DPT head_type must be "regression" or "semseg".') + + if self.dim_tokens_enc is not None: + self.init(dim_tokens_enc=dim_tokens_enc) + + def init(self, dim_tokens_enc=768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + Should be called when setting up MultiMAE. + + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + #print(dim_tokens_enc) + + # Set up activation postprocessing layers + if isinstance(dim_tokens_enc, int): + dim_tokens_enc = 4 * [dim_tokens_enc] + + self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc] + + self.act_1_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[0], + out_channels=self.layer_dims[0], + kernel_size=1, stride=1, padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[0], + out_channels=self.layer_dims[0], + kernel_size=4, stride=4, padding=0, + bias=True, dilation=1, groups=1, + ) + ) + + self.act_2_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[1], + out_channels=self.layer_dims[1], + kernel_size=1, stride=1, padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[1], + out_channels=self.layer_dims[1], + kernel_size=2, stride=2, padding=0, + bias=True, dilation=1, groups=1, + ) + ) + + self.act_3_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[2], + out_channels=self.layer_dims[2], + kernel_size=1, stride=1, padding=0, + ) + ) + + self.act_4_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[3], + out_channels=self.layer_dims[3], + kernel_size=1, stride=1, padding=0, + ), + nn.Conv2d( + in_channels=self.layer_dims[3], + out_channels=self.layer_dims[3], + kernel_size=3, stride=2, padding=1, + ) + ) + + self.act_postprocess = nn.ModuleList([ + self.act_1_postprocess, + self.act_2_postprocess, + self.act_3_postprocess, + self.act_4_postprocess + ]) + + def adapt_tokens(self, encoder_tokens): + # Adapt tokens + x = [] + x.append(encoder_tokens[:, :]) + x = torch.cat(x, dim=-1) + return x + + def forward(self, encoder_tokens: List[torch.Tensor], image_size): + #input_info: Dict): + assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' + H, W = image_size + + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + # Hook decoder onto 4 layers from specified ViT layers + layers = [encoder_tokens[hook] for hook in self.hooks] + + # Extract only task-relevant tokens and ignore global tokens. + layers = [self.adapt_tokens(l) for l in layers] + + # Reshape tokens to spatial representation + layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers] + + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + # Project layers to chosen feature dim + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[3]) + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + # Output head + out = self.head(path_1) + + return out diff --git a/monst3r/croco/models/head_downstream.py b/monst3r/croco/models/head_downstream.py new file mode 100644 index 0000000..bd40c91 --- /dev/null +++ b/monst3r/croco/models/head_downstream.py @@ -0,0 +1,58 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Heads for downstream tasks +# -------------------------------------------------------- + +""" +A head is a module where the __init__ defines only the head hyperparameters. +A method setup(croconet) takes a CroCoNet and set all layers according to the head and croconet attributes. +The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height' +""" + +import torch +import torch.nn as nn +from .dpt_block import DPTOutputAdapter + + +class PixelwiseTaskWithDPT(nn.Module): + """ DPT module for CroCo. + by default, hooks_idx will be equal to: + * for encoder-only: 4 equally spread layers + * for encoder+decoder: last encoder + 3 equally spread layers of the decoder + """ + + def __init__(self, *, hooks_idx=None, layer_dims=[96,192,384,768], + output_width_ratio=1, num_channels=1, postprocess=None, **kwargs): + super(PixelwiseTaskWithDPT, self).__init__() + self.return_all_blocks = True # backbone needs to return all layers + self.postprocess = postprocess + self.output_width_ratio = output_width_ratio + self.num_channels = num_channels + self.hooks_idx = hooks_idx + self.layer_dims = layer_dims + + def setup(self, croconet): + dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels} + if self.hooks_idx is None: + if hasattr(croconet, 'dec_blocks'): # encoder + decoder + step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth] + hooks_idx = [croconet.dec_depth+croconet.enc_depth-1-i*step for i in range(3,-1,-1)] + else: # encoder only + step = croconet.enc_depth//4 + hooks_idx = [croconet.enc_depth-1-i*step for i in range(3,-1,-1)] + self.hooks_idx = hooks_idx + print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}') + dpt_args['hooks'] = self.hooks_idx + dpt_args['layer_dims'] = self.layer_dims + self.dpt = DPTOutputAdapter(**dpt_args) + dim_tokens = [croconet.enc_embed_dim if hook0: + pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +#---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +#---------------------------------------------------------- + +try: + from models.curope import cuRoPE2D + RoPE2D = cuRoPE2D +except ImportError: + print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') + + class RoPE2D(torch.nn.Module): + + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D,seq_len,device,dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D,seq_len,device,dtype] = (cos,sin) + return self.cache[D,seq_len,device,dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim==2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two" + D = tokens.size(3) // 2 + assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:,:,0], cos, sin) + x = self.apply_rope1d(x, positions[:,:,1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens \ No newline at end of file diff --git a/monst3r/croco/pretrain.py b/monst3r/croco/pretrain.py new file mode 100644 index 0000000..2c45e48 --- /dev/null +++ b/monst3r/croco/pretrain.py @@ -0,0 +1,254 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Pre-training CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import argparse +import datetime +import json +import numpy as np +import os +import sys +import time +import math +from pathlib import Path +from typing import Iterable + +import torch +import torch.distributed as dist +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +import utils.misc as misc +from utils.misc import NativeScalerWithGradNormCount as NativeScaler +from models.croco import CroCoNet +from models.criterion import MaskedMSE +from datasets.pairs_dataset import PairsDataset + + +def get_args_parser(): + parser = argparse.ArgumentParser('CroCo pre-training', add_help=False) + # model and criterion + parser.add_argument('--model', default='CroCoNet()', type=str, help="string containing the model to build") + parser.add_argument('--norm_pix_loss', default=1, choices=[0,1], help="apply per-patch mean/std normalization before applying the loss") + # dataset + parser.add_argument('--dataset', default='habitat_release', type=str, help="training set") + parser.add_argument('--transforms', default='crop224+acolor', type=str, help="transforms to apply") # in the paper, we also use some homography and rotation, but find later that they were not useful or even harmful + # training + parser.add_argument('--seed', default=0, type=int, help="Random seed") + parser.add_argument('--batch_size', default=64, type=int, help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus") + parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler") + parser.add_argument('--max_epoch', default=400, type=int, help="Stop training at this epoch") + parser.add_argument('--accum_iter', default=1, type=int, help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)") + parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)") + parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)') + parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR', help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') + parser.add_argument('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0') + parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR') + parser.add_argument('--amp', type=int, default=1, choices=[0,1], help="Use Automatic Mixed Precision for pretraining") + # others + parser.add_argument('--num_workers', default=8, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--save_freq', default=1, type=int, help='frequence (number of epochs) to save checkpoint in checkpoint-last.pth') + parser.add_argument('--keep_freq', default=20, type=int, help='frequence (number of epochs) to save checkpoint in checkpoint-%d.pth') + parser.add_argument('--print_freq', default=20, type=int, help='frequence (number of iterations) to print infos while training') + # paths + parser.add_argument('--output_dir', default='./output/', type=str, help="path where to save the output") + parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored") + return parser + + + + +def main(args): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + world_size = misc.get_world_size() + + print("output_dir: "+args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + # auto resume + last_ckpt_fname = os.path.join(args.output_dir, f'checkpoint-last.pth') + args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None + + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(', ', ',\n')) + + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # fix the seed + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = True + + ## training dataset and loader + print('Building dataset for {:s} with transforms {:s}'.format(args.dataset, args.transforms)) + dataset = PairsDataset(args.dataset, trfs=args.transforms, data_dir=args.data_dir) + if world_size>1: + sampler_train = torch.utils.data.DistributedSampler( + dataset, num_replicas=world_size, rank=global_rank, shuffle=True + ) + print("Sampler_train = %s" % str(sampler_train)) + else: + sampler_train = torch.utils.data.RandomSampler(dataset) + data_loader_train = torch.utils.data.DataLoader( + dataset, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + + ## model + print('Loading model: {:s}'.format(args.model)) + model = eval(args.model) + print('Loading criterion: MaskedMSE(norm_pix_loss={:s})'.format(str(bool(args.norm_pix_loss)))) + criterion = MaskedMSE(norm_pix_loss=bool(args.norm_pix_loss)) + + model.to(device) + model_without_ddp = model + print("Model = %s" % str(model_without_ddp)) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + if args.lr is None: # only base_lr is specified + args.lr = args.blr * eff_batch_size / 256 + print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) + print("actual lr: %.2e" % args.lr) + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True) + model_without_ddp = model.module + + param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) # following timm: set wd as 0 for bias and norm layers + optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) + print(optimizer) + loss_scaler = NativeScaler() + + misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) + + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir) + else: + log_writer = None + + print(f"Start training until {args.max_epoch} epochs") + start_time = time.time() + for epoch in range(args.start_epoch, args.max_epoch): + if world_size>1: + data_loader_train.sampler.set_epoch(epoch) + + train_stats = train_one_epoch( + model, criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + log_writer=log_writer, + args=args + ) + + if args.output_dir and epoch % args.save_freq == 0 : + misc.save_model( + args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch, fname='last') + + if args.output_dir and (epoch % args.keep_freq == 0 or epoch + 1 == args.max_epoch) and (epoch>0 or args.max_epoch==1): + misc.save_model( + args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch) + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch,} + + if args.output_dir and misc.is_main_process(): + if log_writer is not None: + log_writer.flush() + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, + log_writer=None, + args=None): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + accum_iter = args.accum_iter + + optimizer.zero_grad() + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + for data_iter_step, (image1, image2) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): + + # we use a per iteration lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) + + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + with torch.cuda.amp.autocast(enabled=bool(args.amp)): + out, mask, target = model(image1, image2) + loss = criterion(out, mask, target) + + loss_value = loss.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) + + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(lr=lr) + + loss_value_reduce = misc.all_reduce_mean(loss_value) + if log_writer is not None and ((data_iter_step + 1) % (accum_iter*args.print_freq)) == 0: + # x-axis is based on epoch_1000x in the tensorboard, calibrating differences curves when batch size changes + epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) + log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) + log_writer.add_scalar('lr', lr, epoch_1000x) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + main(args) diff --git a/monst3r/croco/stereoflow/README.MD b/monst3r/croco/stereoflow/README.MD new file mode 100644 index 0000000..8159538 --- /dev/null +++ b/monst3r/croco/stereoflow/README.MD @@ -0,0 +1,318 @@ +## CroCo-Stereo and CroCo-Flow + +This README explains how to use CroCo-Stereo and CroCo-Flow as well as how they were trained. +All commands should be launched from the root directory. + +### Simple inference example + +We provide a simple inference exemple for CroCo-Stereo and CroCo-Flow in the Totebook `croco-stereo-flow-demo.ipynb`. +Before running it, please download the trained models with: +``` +bash stereoflow/download_model.sh crocostereo.pth +bash stereoflow/download_model.sh crocoflow.pth +``` + +### Prepare data for training or evaluation + +Put the datasets used for training/evaluation in `./data/stereoflow` (or update the paths at the top of `stereoflow/datasets_stereo.py` and `stereoflow/datasets_flow.py`). +Please find below on the file structure should look for each dataset: +
+FlyingChairs + +``` +./data/stereoflow/FlyingChairs/ +└───chairs_split.txt +└───data/ + └─── ... +``` +
+ +
+MPI-Sintel + +``` +./data/stereoflow/MPI-Sintel/ +└───training/ +│ └───clean/ +│ └───final/ +│ └───flow/ +└───test/ + └───clean/ + └───final/ +``` +
+ +
+SceneFlow (including FlyingThings) + +``` +./data/stereoflow/SceneFlow/ +└───Driving/ +│ └───disparity/ +│ └───frames_cleanpass/ +│ └───frames_finalpass/ +└───FlyingThings/ +│ └───disparity/ +│ └───frames_cleanpass/ +│ └───frames_finalpass/ +│ └───optical_flow/ +└───Monkaa/ + └───disparity/ + └───frames_cleanpass/ + └───frames_finalpass/ +``` +
+ +
+TartanAir + +``` +./data/stereoflow/TartanAir/ +└───abandonedfactory/ +│ └───.../ +└───abandonedfactory_night/ +│ └───.../ +└───.../ +``` +
+ +
+Booster + +``` +./data/stereoflow/booster_gt/ +└───train/ + └───balanced/ + └───Bathroom/ + └───Bedroom/ + └───... +``` +
+ +
+CREStereo + +``` +./data/stereoflow/crenet_stereo_trainset/ +└───stereo_trainset/ + └───crestereo/ + └───hole/ + └───reflective/ + └───shapenet/ + └───tree/ +``` +
+ +
+ETH3D Two-view Low-res + +``` +./data/stereoflow/eth3d_lowres/ +└───test/ +│ └───lakeside_1l/ +│ └───... +└───train/ +│ └───delivery_area_1l/ +│ └───... +└───train_gt/ + └───delivery_area_1l/ + └───... +``` +
+ +
+KITTI 2012 + +``` +./data/stereoflow/kitti-stereo-2012/ +└───testing/ +│ └───colored_0/ +│ └───colored_1/ +└───training/ + └───colored_0/ + └───colored_1/ + └───disp_occ/ + └───flow_occ/ +``` +
+ +
+KITTI 2015 + +``` +./data/stereoflow/kitti-stereo-2015/ +└───testing/ +│ └───image_2/ +│ └───image_3/ +└───training/ + └───image_2/ + └───image_3/ + └───disp_occ_0/ + └───flow_occ/ +``` +
+ +
+Middlebury + +``` +./data/stereoflow/middlebury +└───2005/ +│ └───train/ +│ └───Art/ +│ └───... +└───2006/ +│ └───Aloe/ +│ └───Baby1/ +│ └───... +└───2014/ +│ └───Adirondack-imperfect/ +│ └───Adirondack-perfect/ +│ └───... +└───2021/ +│ └───data/ +│ └───artroom1/ +│ └───artroom2/ +│ └───... +└───MiddEval3_F/ + └───test/ + │ └───Australia/ + │ └───... + └───train/ + └───Adirondack/ + └───... +``` +
+ +
+Spring + +``` +./data/stereoflow/spring/ +└───test/ +│ └───0003/ +│ └───... +└───train/ + └───0001/ + └───... +``` +
+ + +### CroCo-Stereo + +##### Main model + +The main training of CroCo-Stereo was performed on a series of datasets, and it was used as it for Middlebury v3 benchmark. + +``` +# Download the model +bash stereoflow/download_model.sh crocostereo.pth +# Middlebury v3 submission +python stereoflow/test.py --model stereoflow_models/crocostereo.pth --dataset "MdEval3('all_full')" --save submission --tile_overlap 0.9 +# Training command that was used, using checkpoint-last.pth +python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main/ +# or it can be launched on multiple gpus (while maintaining the effective batch size), e.g. on 3 gpus: +torchrun --nproc_per_node 3 stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 2 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main/ +``` + +For evaluation of validation set, we also provide the model trained on the `subtrain` subset of the training sets. + +``` +# Download the model +bash stereoflow/download_model.sh crocostereo_subtrain.pth +# Evaluation on validation sets +python stereoflow/test.py --model stereoflow_models/crocostereo_subtrain.pth --dataset "MdEval3('subval_full')+ETH3DLowRes('subval')+SceneFlow('test_finalpass')+SceneFlow('test_cleanpass')" --save metrics --tile_overlap 0.9 +# Training command that was used (same as above but on subtrain, using checkpoint-best.pth), can also be launched on multiple gpus +python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('subtrain')+50*Md05('subtrain')+50*Md06('subtrain')+50*Md14('subtrain')+50*Md21('subtrain')+50*MdEval3('subtrain_full')+Booster('subtrain_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main_subtrain/ +``` + +##### Other models + +
+ Model for ETH3D + The model used for the submission on ETH3D is trained with the same command but using an unbounded Laplacian loss. + + # Download the model + bash stereoflow/download_model.sh crocostereo_eth3d.pth + # ETH3D submission + python stereoflow/test.py --model stereoflow_models/crocostereo_eth3d.pth --dataset "ETH3DLowRes('all')" --save submission --tile_overlap 0.9 + # Training command that was used + python -u stereoflow/train.py stereo --criterion "LaplacianLoss()" --tile_conf_mode conf_expbeta3 --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main_eth3d/ + +
+ +
+ Main model finetuned on Kitti + + # Download the model + bash stereoflow/download_model.sh crocostereo_finetune_kitti.pth + # Kitti submission + python stereoflow/test.py --model stereoflow_models/crocostereo_finetune_kitti.pth --dataset "Kitti15('test')" --save submission --tile_overlap 0.9 + # Training that was used + python -u stereoflow/train.py stereo --crop 352 1216 --criterion "LaplacianLossBounded2()" --dataset "Kitti12('train')+Kitti15('train')" --lr 3e-5 --batch_size 1 --accum_iter 6 --epochs 20 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocostereo.pth --output_dir xps/crocostereo/finetune_kitti/ --save_every 5 +
+ +
+ Main model finetuned on Spring + + # Download the model + bash stereoflow/download_model.sh crocostereo_finetune_spring.pth + # Spring submission + python stereoflow/test.py --model stereoflow_models/crocostereo_finetune_spring.pth --dataset "Spring('test')" --save submission --tile_overlap 0.9 + # Training command that was used + python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "Spring('train')" --lr 3e-5 --batch_size 6 --epochs 8 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocostereo.pth --output_dir xps/crocostereo/finetune_spring/ +
+ +
+ Smaller models + To train CroCo-Stereo with smaller CroCo pretrained models, simply replace the --pretrained argument. To download the smaller CroCo-Stereo models based on CroCo v2 pretraining with ViT-Base encoder and Small encoder, use bash stereoflow/download_model.sh crocostereo_subtrain_vitb_smalldecoder.pth, and for the model with a ViT-Base encoder and a Base decoder, use bash stereoflow/download_model.sh crocostereo_subtrain_vitb_basedecoder.pth. +
+ + +### CroCo-Flow + +##### Main model + +The main training of CroCo-Flow was performed on the FlyingThings, FlyingChairs, MPI-Sintel and TartanAir datasets. +It was used for our submission to the MPI-Sintel benchmark. + +``` +# Download the model +bash stereoflow/download_model.sh crocoflow.pth +# Evaluation +python stereoflow/test.py --model stereoflow_models/crocoflow.pth --dataset "MPISintel('subval_cleanpass')+MPISintel('subval_finalpass')" --save metrics --tile_overlap 0.9 +# Sintel submission +python stereoflow/test.py --model stereoflow_models/crocoflow.pth --dataset "MPISintel('test_allpass')" --save submission --tile_overlap 0.9 +# Training command that was used, with checkpoint-best.pth +python -u stereoflow/train.py flow --criterion "LaplacianLossBounded()" --dataset "40*MPISintel('subtrain_cleanpass')+40*MPISintel('subtrain_finalpass')+4*FlyingThings('train_allpass')+4*FlyingChairs('train')+TartanAir('train')" --val_dataset "MPISintel('subval_cleanpass')+MPISintel('subval_finalpass')" --lr 2e-5 --batch_size 8 --epochs 240 --img_per_epoch 30000 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocoflow/main/ +``` + +##### Other models + +
+ Main model finetuned on Kitti + + # Download the model + bash stereoflow/download_model.sh crocoflow_finetune_kitti.pth + # Kitti submission + python stereoflow/test.py --model stereoflow_models/crocoflow_finetune_kitti.pth --dataset "Kitti15('test')" --save submission --tile_overlap 0.99 + # Training that was used, with checkpoint-last.pth + python -u stereoflow/train.py flow --crop 352 1216 --criterion "LaplacianLossBounded()" --dataset "Kitti15('train')+Kitti12('train')" --lr 2e-5 --batch_size 1 --accum_iter 8 --epochs 150 --save_every 5 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocoflow.pth --output_dir xps/crocoflow/finetune_kitti/ +
+ +
+ Main model finetuned on Spring + + # Download the model + bash stereoflow/download_model.sh crocoflow_finetune_spring.pth + # Spring submission + python stereoflow/test.py --model stereoflow_models/crocoflow_finetune_spring.pth --dataset "Spring('test')" --save submission --tile_overlap 0.9 + # Training command that was used, with checkpoint-last.pth + python -u stereoflow/train.py flow --criterion "LaplacianLossBounded()" --dataset "Spring('train')" --lr 2e-5 --batch_size 8 --epochs 12 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocoflow.pth --output_dir xps/crocoflow/finetune_spring/ +
+ +
+ Smaller models + To train CroCo-Flow with smaller CroCo pretrained models, simply replace the --pretrained argument. To download the smaller CroCo-Flow models based on CroCo v2 pretraining with ViT-Base encoder and Small encoder, use bash stereoflow/download_model.sh crocoflow_vitb_smalldecoder.pth, and for the model with a ViT-Base encoder and a Base decoder, use bash stereoflow/download_model.sh crocoflow_vitb_basedecoder.pth. +
diff --git a/monst3r/croco/stereoflow/augmentor.py b/monst3r/croco/stereoflow/augmentor.py new file mode 100644 index 0000000..69e6117 --- /dev/null +++ b/monst3r/croco/stereoflow/augmentor.py @@ -0,0 +1,290 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Data augmentation for training stereo and flow +# -------------------------------------------------------- + +# References +# https://github.com/autonomousvision/unimatch/blob/master/dataloader/stereo/transforms.py +# https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/transforms.py + + +import numpy as np +import random +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torchvision.transforms.functional as FF + +class StereoAugmentor(object): + + def __init__(self, crop_size, scale_prob=0.5, scale_xonly=True, lhth=800., lminscale=0.0, lmaxscale=1.0, hminscale=-0.2, hmaxscale=0.4, scale_interp_nearest=True, rightjitterprob=0.5, v_flip_prob=0.5, color_aug_asym=True, color_choice_prob=0.5): + self.crop_size = crop_size + self.scale_prob = scale_prob + self.scale_xonly = scale_xonly + self.lhth = lhth + self.lminscale = lminscale + self.lmaxscale = lmaxscale + self.hminscale = hminscale + self.hmaxscale = hmaxscale + self.scale_interp_nearest = scale_interp_nearest + self.rightjitterprob = rightjitterprob + self.v_flip_prob = v_flip_prob + self.color_aug_asym = color_aug_asym + self.color_choice_prob = color_choice_prob + + def _random_scale(self, img1, img2, disp): + ch,cw = self.crop_size + h,w = img1.shape[:2] + if self.scale_prob>0. and np.random.rand()1.: + scale_x = clip_scale + scale_y = scale_x if not self.scale_xonly else 1.0 + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + disp = cv2.resize(disp, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR if not self.scale_interp_nearest else cv2.INTER_NEAREST) * scale_x + return img1, img2, disp + + def _random_crop(self, img1, img2, disp): + h,w = img1.shape[:2] + ch,cw = self.crop_size + assert ch<=h and cw<=w, (img1.shape, h,w,ch,cw) + offset_x = np.random.randint(w - cw + 1) + offset_y = np.random.randint(h - ch + 1) + img1 = img1[offset_y:offset_y+ch,offset_x:offset_x+cw] + img2 = img2[offset_y:offset_y+ch,offset_x:offset_x+cw] + disp = disp[offset_y:offset_y+ch,offset_x:offset_x+cw] + return img1, img2, disp + + def _random_vflip(self, img1, img2, disp): + # vertical flip + if self.v_flip_prob>0 and np.random.rand() < self.v_flip_prob: + img1 = np.copy(np.flipud(img1)) + img2 = np.copy(np.flipud(img2)) + disp = np.copy(np.flipud(disp)) + return img1, img2, disp + + def _random_rotate_shift_right(self, img2): + if self.rightjitterprob>0. and np.random.rand() 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow = np.inf * np.ones([ht1, wd1, 2], dtype=np.float32) # invalid value every where, before we fill it with the correct ones + flow[yy, xx] = flow1 + return flow + + def spatial_transform(self, img1, img2, flow, dname): + + if np.random.rand() < self.spatial_aug_prob: + # randomly sample scale + ht, wd = img1.shape[:2] + clip_min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + min_scale, max_scale = self.min_scale, self.max_scale + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_x = np.clip(scale_x, clip_min_scale, None) + scale_y = np.clip(scale_y, clip_min_scale, None) + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = self._resize_flow(flow, scale_x, scale_y, factor=2.0 if dname=='Spring' else 1.0) + elif dname=="Spring": + flow = self._resize_flow(flow, 1.0, 1.0, factor=2.0) + + if self.h_flip_prob>0. and np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if self.v_flip_prob>0. and np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + # In case no cropping + if img1.shape[0] - self.crop_size[0] > 0: + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + else: + y0 = 0 + if img1.shape[1] - self.crop_size[1] > 0: + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + else: + x0 = 0 + + img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow, dname): + img1, img2, flow = self.spatial_transform(img1, img2, flow, dname) + img1, img2 = self.color_transform(img1, img2) + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + return img1, img2, flow \ No newline at end of file diff --git a/monst3r/croco/stereoflow/criterion.py b/monst3r/croco/stereoflow/criterion.py new file mode 100644 index 0000000..57792eb --- /dev/null +++ b/monst3r/croco/stereoflow/criterion.py @@ -0,0 +1,251 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Losses, metrics per batch, metrics per dataset +# -------------------------------------------------------- + +import torch +from torch import nn +import torch.nn.functional as F + +def _get_gtnorm(gt): + if gt.size(1)==1: # stereo + return gt + # flow + return torch.sqrt(torch.sum(gt**2, dim=1, keepdims=True)) # Bx1xHxW + +############ losses without confidence + +class L1Loss(nn.Module): + + def __init__(self, max_gtnorm=None): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = False + + def _error(self, gt, predictions): + return torch.abs(gt-predictions) + + def forward(self, predictions, gt, inspect=False): + mask = torch.isfinite(gt) + if self.max_gtnorm is not None: + mask *= _get_gtnorm(gt).expand(-1,gt.size(1),-1,-1) which is a constant + + +class LaplacianLossBounded(nn.Module): # used for CroCo-Flow ; in the equation of the paper, we have a=1/b + def __init__(self, max_gtnorm=10000., a=0.25, b=4.): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = True + self.a, self.b = a, b + + def forward(self, predictions, gt, conf): + mask = torch.isfinite(gt) + mask = mask[:,0,:,:] + if self.max_gtnorm is not None: mask *= _get_gtnorm(gt)[:,0,:,:] which is a constant + +class LaplacianLossBounded2(nn.Module): # used for CroCo-Stereo (except for ETH3D) ; in the equation of the paper, we have a=b + def __init__(self, max_gtnorm=None, a=3.0, b=3.0): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = True + self.a, self.b = a, b + + def forward(self, predictions, gt, conf): + mask = torch.isfinite(gt) + mask = mask[:,0,:,:] + if self.max_gtnorm is not None: mask *= _get_gtnorm(gt)[:,0,:,:] which is a constant + +############## metrics per batch + +class StereoMetrics(nn.Module): + + def __init__(self, do_quantile=False): + super().__init__() + self.bad_ths = [0.5,1,2,3] + self.do_quantile = do_quantile + + def forward(self, predictions, gt): + B = predictions.size(0) + metrics = {} + gtcopy = gt.clone() + mask = torch.isfinite(gtcopy) + gtcopy[~mask] = 999999.0 # we make a copy and put a non-infinite value, such that it does not become nan once multiplied by the mask value 0 + Npx = mask.view(B,-1).sum(dim=1) + L1error = (torch.abs(gtcopy-predictions)*mask).view(B,-1) + L2error = (torch.square(gtcopy-predictions)*mask).view(B,-1) + # avgerr + metrics['avgerr'] = torch.mean(L1error.sum(dim=1)/Npx ) + # rmse + metrics['rmse'] = torch.sqrt(L2error.sum(dim=1)/Npx).mean(dim=0) + # err > t for t in [0.5,1,2,3] + for ths in self.bad_ths: + metrics['bad@{:.1f}'.format(ths)] = (((L1error>ths)* mask.view(B,-1)).sum(dim=1)/Npx).mean(dim=0) * 100 + return metrics + +class FlowMetrics(nn.Module): + def __init__(self): + super().__init__() + self.bad_ths = [1,3,5] + + def forward(self, predictions, gt): + B = predictions.size(0) + metrics = {} + mask = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite + Npx = mask.view(B,-1).sum(dim=1) + gtcopy = gt.clone() # to compute L1/L2 error, we need to have non-infinite value, the error computed at this locations will be ignored + gtcopy[:,0,:,:][~mask] = 999999.0 + gtcopy[:,1,:,:][~mask] = 999999.0 + L1error = (torch.abs(gtcopy-predictions).sum(dim=1)*mask).view(B,-1) + L2error = (torch.sqrt(torch.sum(torch.square(gtcopy-predictions),dim=1))*mask).view(B,-1) + metrics['L1err'] = torch.mean(L1error.sum(dim=1)/Npx ) + metrics['EPE'] = torch.mean(L2error.sum(dim=1)/Npx ) + for ths in self.bad_ths: + metrics['bad@{:.1f}'.format(ths)] = (((L2error>ths)* mask.view(B,-1)).sum(dim=1)/Npx).mean(dim=0) * 100 + return metrics + +############## metrics per dataset +## we update the average and maintain the number of pixels while adding data batch per batch +## at the beggining, call reset() +## after each batch, call add_batch(...) +## at the end: call get_results() + +class StereoDatasetMetrics(nn.Module): + + def __init__(self): + super().__init__() + self.bad_ths = [0.5,1,2,3] + + def reset(self): + self.agg_N = 0 # number of pixels so far + self.agg_L1err = torch.tensor(0.0) # L1 error so far + self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels + self._metrics = None + + def add_batch(self, predictions, gt): + assert predictions.size(1)==1, predictions.size() + assert gt.size(1)==1, gt.size() + if gt.size(2)==predictions.size(2)*2 and gt.size(3)==predictions.size(3)*2: # special case for Spring ... + L1err = torch.minimum( torch.minimum( torch.minimum( + torch.sum(torch.abs(gt[:,:,0::2,0::2]-predictions),dim=1), + torch.sum(torch.abs(gt[:,:,1::2,0::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,0::2,1::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,1::2,1::2]-predictions),dim=1)) + valid = torch.isfinite(L1err) + else: + valid = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite + L1err = torch.sum(torch.abs(gt-predictions),dim=1) + N = valid.sum() + Nnew = self.agg_N + N + self.agg_L1err = float(self.agg_N)/Nnew * self.agg_L1err + L1err[valid].mean().cpu() * float(N)/Nnew + self.agg_N = Nnew + for i,th in enumerate(self.bad_ths): + self.agg_Nbad[i] += (L1err[valid]>th).sum().cpu() + + def _compute_metrics(self): + if self._metrics is not None: return + out = {} + out['L1err'] = self.agg_L1err.item() + for i,th in enumerate(self.bad_ths): + out['bad@{:.1f}'.format(th)] = (float(self.agg_Nbad[i]) / self.agg_N).item() * 100.0 + self._metrics = out + + def get_results(self): + self._compute_metrics() # to avoid recompute them multiple times + return self._metrics + +class FlowDatasetMetrics(nn.Module): + + def __init__(self): + super().__init__() + self.bad_ths = [0.5,1,3,5] + self.speed_ths = [(0,10),(10,40),(40,torch.inf)] + + def reset(self): + self.agg_N = 0 # number of pixels so far + self.agg_L1err = torch.tensor(0.0) # L1 error so far + self.agg_L2err = torch.tensor(0.0) # L2 (=EPE) error so far + self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels + self.agg_EPEspeed = [torch.tensor(0.0) for _ in self.speed_ths] # EPE per speed bin so far + self.agg_Nspeed = [0 for _ in self.speed_ths] # N pixels per speed bin so far + self._metrics = None + self.pairname_results = {} + + def add_batch(self, predictions, gt): + assert predictions.size(1)==2, predictions.size() + assert gt.size(1)==2, gt.size() + if gt.size(2)==predictions.size(2)*2 and gt.size(3)==predictions.size(3)*2: # special case for Spring ... + L1err = torch.minimum( torch.minimum( torch.minimum( + torch.sum(torch.abs(gt[:,:,0::2,0::2]-predictions),dim=1), + torch.sum(torch.abs(gt[:,:,1::2,0::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,0::2,1::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,1::2,1::2]-predictions),dim=1)) + L2err = torch.minimum( torch.minimum( torch.minimum( + torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,0::2]-predictions),dim=1)), + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,0::2]-predictions),dim=1))), + torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,1::2]-predictions),dim=1))), + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,1::2]-predictions),dim=1))) + valid = torch.isfinite(L1err) + gtspeed = (torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,0::2]),dim=1)) + torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,1::2]),dim=1)) +\ + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,0::2]),dim=1)) + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,1::2]),dim=1)) ) / 4.0 # let's just average them + else: + valid = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite + L1err = torch.sum(torch.abs(gt-predictions),dim=1) + L2err = torch.sqrt(torch.sum(torch.square(gt-predictions),dim=1)) + gtspeed = torch.sqrt(torch.sum(torch.square(gt),dim=1)) + N = valid.sum() + Nnew = self.agg_N + N + self.agg_L1err = float(self.agg_N)/Nnew * self.agg_L1err + L1err[valid].mean().cpu() * float(N)/Nnew + self.agg_L2err = float(self.agg_N)/Nnew * self.agg_L2err + L2err[valid].mean().cpu() * float(N)/Nnew + self.agg_N = Nnew + for i,th in enumerate(self.bad_ths): + self.agg_Nbad[i] += (L2err[valid]>th).sum().cpu() + for i,(th1,th2) in enumerate(self.speed_ths): + vv = (gtspeed[valid]>=th1) * (gtspeed[valid] don't use batch_size>1 at test time) + self._prepare_data() + self._load_or_build_cache() + + def prepare_data(self): + """ + to be defined for each dataset + """ + raise NotImplementedError + + def __len__(self): + return len(self.pairnames) # each pairname is typically of the form (str, int1, int2) + + def __getitem__(self, index): + pairname = self.pairnames[index] + + # get filenames + img1name = self.pairname_to_img1name(pairname) + img2name = self.pairname_to_img2name(pairname) + flowname = self.pairname_to_flowname(pairname) if self.pairname_to_flowname is not None else None + + # load images and disparities + img1 = _read_img(img1name) + img2 = _read_img(img2name) + flow = self.load_flow(flowname) if flowname is not None else None + + # apply augmentations + if self.augmentor is not None: + img1, img2, flow = self.augmentor(img1, img2, flow, self.name) + + if self.totensor: + img1 = img_to_tensor(img1) + img2 = img_to_tensor(img2) + if flow is not None: + flow = flow_to_tensor(flow) + else: + flow = torch.tensor([]) # to allow dataloader batching with default collate_gn + pairname = str(pairname) # transform potential tuple to str to be able to batch it + + return img1, img2, flow, pairname + + def __rmul__(self, v): + self.rmul *= v + self.pairnames = v * self.pairnames + return self + + def __str__(self): + return f'{self.__class__.__name__}_{self.split}' + + def __repr__(self): + s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})' + if self.rmul==1: + s+=f'\n\tnum pairs: {len(self.pairnames)}' + else: + s+=f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})' + return s + + def _set_root(self): + self.root = dataset_to_root[self.name] + assert os.path.isdir(self.root), f"could not find root directory for dataset {self.name}: {self.root}" + + def _load_or_build_cache(self): + cache_file = osp.join(cache_dir, self.name+'.pkl') + if osp.isfile(cache_file): + with open(cache_file, 'rb') as fid: + self.pairnames = pickle.load(fid)[self.split] + else: + tosave = self._build_cache() + os.makedirs(cache_dir, exist_ok=True) + with open(cache_file, 'wb') as fid: + pickle.dump(tosave, fid) + self.pairnames = tosave[self.split] + +class TartanAirDataset(FlowDataset): + + def _prepare_data(self): + self.name = "TartanAir" + self._set_root() + assert self.split in ['train'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], 'image_left/{:06d}_left.png'.format(pairname[1])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], 'image_left/{:06d}_left.png'.format(pairname[2])) + self.pairname_to_flowname = lambda pairname: osp.join(self.root, pairname[0], 'flow/{:06d}_{:06d}_flow.npy'.format(pairname[1],pairname[2])) + self.pairname_to_str = lambda pairname: os.path.join(pairname[0][pairname[0].find('/')+1:], '{:06d}_{:06d}'.format(pairname[1], pairname[2])) + self.load_flow = _read_numpy_flow + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + pairs = [(osp.join(s,s,difficulty,Pxxx),int(a[:6]),int(a[:6])+1) for s in seqs for difficulty in ['Easy','Hard'] for Pxxx in sorted(os.listdir(osp.join(self.root,s,s,difficulty))) for a in sorted(os.listdir(osp.join(self.root,s,s,difficulty,Pxxx,'image_left/')))[:-1]] + assert len(pairs)==306268, "incorrect parsing of pairs in TartanAir" + tosave = {'train': pairs} + return tosave + +class FlyingChairsDataset(FlowDataset): + + def _prepare_data(self): + self.name = "FlyingChairs" + self._set_root() + assert self.split in ['train','val'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, 'data', pairname+'_img1.ppm') + self.pairname_to_img2name = lambda pairname: osp.join(self.root, 'data', pairname+'_img2.ppm') + self.pairname_to_flowname = lambda pairname: osp.join(self.root, 'data', pairname+'_flow.flo') + self.pairname_to_str = lambda pairname: pairname + self.load_flow = _read_flo_file + + def _build_cache(self): + split_file = osp.join(self.root, 'chairs_split.txt') + split_list = np.loadtxt(split_file, dtype=np.int32) + trainpairs = ['{:05d}'.format(i) for i in np.where(split_list==1)[0]+1] + valpairs = ['{:05d}'.format(i) for i in np.where(split_list==2)[0]+1] + assert len(trainpairs)==22232 and len(valpairs)==640, "incorrect parsing of pairs in MPI-Sintel" + tosave = {'train': trainpairs, 'val': valpairs} + return tosave + +class FlyingThingsDataset(FlowDataset): + + def _prepare_data(self): + self.name = "FlyingThings" + self._set_root() + assert self.split in [f'{set_}_{pass_}pass{camstr}' for set_ in ['train','test','test1024'] for camstr in ['','_rightcam'] for pass_ in ['clean','final','all']] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, f'frames_{pairname[3]}pass', pairname[0].replace('into_future','').replace('into_past',''), '{:04d}.png'.format(pairname[1])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, f'frames_{pairname[3]}pass', pairname[0].replace('into_future','').replace('into_past',''), '{:04d}.png'.format(pairname[2])) + self.pairname_to_flowname = lambda pairname: osp.join(self.root, 'optical_flow', pairname[0], 'OpticalFlowInto{f:s}_{i:04d}_{c:s}.pfm'.format(f='Future' if 'future' in pairname[0] else 'Past', i=pairname[1], c='L' if 'left' in pairname[0] else 'R' )) + self.pairname_to_str = lambda pairname: os.path.join(pairname[3]+'pass', pairname[0], 'Into{f:s}_{i:04d}_{c:s}'.format(f='Future' if 'future' in pairname[0] else 'Past', i=pairname[1], c='L' if 'left' in pairname[0] else 'R' )) + self.load_flow = _read_pfm_flow + + def _build_cache(self): + tosave = {} + # train and test splits for the different passes + for set_ in ['train', 'test']: + sroot = osp.join(self.root, 'optical_flow', set_.upper()) + fname_to_i = lambda f: int(f[len('OpticalFlowIntoFuture_'):-len('_L.pfm')]) + pp = [(osp.join(set_.upper(), d, s, 'into_future/left'),fname_to_i(fname)) for d in sorted(os.listdir(sroot)) for s in sorted(os.listdir(osp.join(sroot,d))) for fname in sorted(os.listdir(osp.join(sroot,d, s, 'into_future/left')))[:-1]] + pairs = [(a,i,i+1) for a,i in pp] + pairs += [(a.replace('into_future','into_past'),i+1,i) for a,i in pp] + assert len(pairs)=={'train': 40302, 'test': 7866}[set_], "incorrect parsing of pairs Flying Things" + for cam in ['left','right']: + camstr = '' if cam=='left' else f'_{cam}cam' + for pass_ in ['final', 'clean']: + tosave[f'{set_}_{pass_}pass{camstr}'] = [(a.replace('left',cam),i,j,pass_) for a,i,j in pairs] + tosave[f'{set_}_allpass{camstr}'] = tosave[f'{set_}_cleanpass{camstr}'] + tosave[f'{set_}_finalpass{camstr}'] + # test1024: this is the same split as unimatch 'validation' split + # see https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/datasets.py#L229 + test1024_nsamples = 1024 + alltest_nsamples = len(tosave['test_cleanpass']) # 7866 + stride = alltest_nsamples // test1024_nsamples + remove = alltest_nsamples % test1024_nsamples + for cam in ['left','right']: + camstr = '' if cam=='left' else f'_{cam}cam' + for pass_ in ['final','clean']: + tosave[f'test1024_{pass_}pass{camstr}'] = sorted(tosave[f'test_{pass_}pass{camstr}'])[:-remove][::stride] # warning, it was not sorted before + assert len(tosave['test1024_cleanpass'])==1024, "incorrect parsing of pairs in Flying Things" + tosave[f'test1024_allpass{camstr}'] = tosave[f'test1024_cleanpass{camstr}'] + tosave[f'test1024_finalpass{camstr}'] + return tosave + + +class MPISintelDataset(FlowDataset): + + def _prepare_data(self): + self.name = "MPISintel" + self._set_root() + assert self.split in [s+'_'+p for s in ['train','test','subval','subtrain'] for p in ['cleanpass','finalpass','allpass']] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], 'frame_{:04d}.png'.format(pairname[1])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], 'frame_{:04d}.png'.format(pairname[1]+1)) + self.pairname_to_flowname = lambda pairname: None if pairname[0].startswith('test/') else osp.join(self.root, pairname[0].replace('/clean/','/flow/').replace('/final/','/flow/'), 'frame_{:04d}.flo'.format(pairname[1])) + self.pairname_to_str = lambda pairname: osp.join(pairname[0], 'frame_{:04d}'.format(pairname[1])) + self.load_flow = _read_flo_file + + def _build_cache(self): + trainseqs = sorted(os.listdir(self.root+'training/clean')) + trainpairs = [ (osp.join('training/clean', s),i) for s in trainseqs for i in range(1, len(os.listdir(self.root+'training/clean/'+s)))] + subvalseqs = ['temple_2','temple_3'] + subtrainseqs = [s for s in trainseqs if s not in subvalseqs] + subvalpairs = [ (p,i) for p,i in trainpairs if any(s in p for s in subvalseqs)] + subtrainpairs = [ (p,i) for p,i in trainpairs if any(s in p for s in subtrainseqs)] + testseqs = sorted(os.listdir(self.root+'test/clean')) + testpairs = [ (osp.join('test/clean', s),i) for s in testseqs for i in range(1, len(os.listdir(self.root+'test/clean/'+s)))] + assert len(trainpairs)==1041 and len(testpairs)==552 and len(subvalpairs)==98 and len(subtrainpairs)==943, "incorrect parsing of pairs in MPI-Sintel" + tosave = {} + tosave['train_cleanpass'] = trainpairs + tosave['test_cleanpass'] = testpairs + tosave['subval_cleanpass'] = subvalpairs + tosave['subtrain_cleanpass'] = subtrainpairs + for t in ['train','test','subval','subtrain']: + tosave[t+'_finalpass'] = [(p.replace('/clean/','/final/'),i) for p,i in tosave[t+'_cleanpass']] + tosave[t+'_allpass'] = tosave[t+'_cleanpass'] + tosave[t+'_finalpass'] + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, _time): + assert prediction.shape[2]==2 + outfile = os.path.join(outdir, 'submission', self.pairname_to_str(pairname)+'.flo') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlowFile(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split == 'test_allpass' + bundle_exe = "/nfs/data/ffs-3d/datasets/StereoFlow/MPI-Sintel/bundler/linux-x64/bundler" # eg + if os.path.isfile(bundle_exe): + cmd = f'{bundle_exe} "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at: "{outdir}/submission/bundled.lzma"') + else: + print('Could not find bundler executable for submission.') + print('Please download it and run:') + print(f' "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"') + +class SpringDataset(FlowDataset): + + def _prepare_data(self): + self.name = "Spring" + self._set_root() + assert self.split in ['train','test','subtrain','subval'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], pairname[1], 'frame_'+pairname[3], 'frame_{:s}_{:04d}.png'.format(pairname[3], pairname[4])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], pairname[1], 'frame_'+pairname[3], 'frame_{:s}_{:04d}.png'.format(pairname[3], pairname[4]+(1 if pairname[2]=='FW' else -1))) + self.pairname_to_flowname = lambda pairname: None if pairname[0]=='test' else osp.join(self.root, pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5') + self.pairname_to_str = lambda pairname: osp.join(pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}') + self.load_flow = _read_hdf5_flow + + def _build_cache(self): + # train + trainseqs = sorted(os.listdir( osp.join(self.root,'train'))) + trainpairs = [] + for leftright in ['left','right']: + for fwbw in ['FW','BW']: + trainpairs += [('train',s,fwbw,leftright,int(f[len(f'flow_{fwbw}_{leftright}_'):-len('.flo5')])) for s in trainseqs for f in sorted(os.listdir(osp.join(self.root,'train',s,f'flow_{fwbw}_{leftright}')))] + # test + testseqs = sorted(os.listdir( osp.join(self.root,'test'))) + testpairs = [] + for leftright in ['left','right']: + testpairs += [('test',s,'FW',leftright,int(f[len(f'frame_{leftright}_'):-len('.png')])) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,f'frame_{leftright}')))[:-1]] + testpairs += [('test',s,'BW',leftright,int(f[len(f'frame_{leftright}_'):-len('.png')])+1) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,f'frame_{leftright}')))[:-1]] + # subtrain / subval + subtrainpairs = [p for p in trainpairs if p[1]!='0041'] + subvalpairs = [p for p in trainpairs if p[1]=='0041'] + assert len(trainpairs)==19852 and len(testpairs)==3960 and len(subtrainpairs)==19472 and len(subvalpairs)==380, "incorrect parsing of pairs in Spring" + tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==3 + assert prediction.shape[2]==2 + assert prediction.dtype==np.float32 + outfile = osp.join(outdir, pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlo5File(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + exe = "{self.root}/flow_subsampling" + if os.path.isfile(exe): + cmd = f'cd "{outdir}/test"; {exe} .' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/test/flow_submission.hdf5') + else: + print('Could not find flow_subsampling executable for submission.') + print('Please download it and run:') + print(f'cd "{outdir}/test"; .') + + +class Kitti12Dataset(FlowDataset): + + def _prepare_data(self): + self.name = "Kitti12" + self._set_root() + assert self.split in ['train','test'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname+'_11.png') + self.pairname_to_flowname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/flow_occ/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/colored_0/','/') + self.load_flow = _read_kitti_flow + + def _build_cache(self): + trainseqs = ["training/colored_0/%06d"%(i) for i in range(194)] + testseqs = ["testing/colored_0/%06d"%(i) for i in range(195)] + assert len(trainseqs)==194 and len(testseqs)==195, "incorrect parsing of pairs in Kitti12" + tosave = {'train': trainseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==3 + assert prediction.shape[2]==2 + outfile = os.path.join(outdir, pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlowKitti(outfile, prediction) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti12_flow_results.zip" .' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/kitti12_flow_results.zip') + + +class Kitti15Dataset(FlowDataset): + + def _prepare_data(self): + self.name = "Kitti15" + self._set_root() + assert self.split in ['train','subtrain','subval','test'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname+'_11.png') + self.pairname_to_flowname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/flow_occ/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/image_2/','/') + self.load_flow = _read_kitti_flow + + def _build_cache(self): + trainseqs = ["training/image_2/%06d"%(i) for i in range(200)] + subtrainseqs = trainseqs[:-10] + subvalseqs = trainseqs[-10:] + testseqs = ["testing/image_2/%06d"%(i) for i in range(200)] + assert len(trainseqs)==200 and len(subtrainseqs)==190 and len(subvalseqs)==10 and len(testseqs)==200, "incorrect parsing of pairs in Kitti15" + tosave = {'train': trainseqs, 'subtrain': subtrainseqs, 'subval': subvalseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==3 + assert prediction.shape[2]==2 + outfile = os.path.join(outdir, 'flow', pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlowKitti(outfile, prediction) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti15_flow_results.zip" flow' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/kitti15_flow_results.zip') + + +import cv2 +def _read_numpy_flow(filename): + return np.load(filename) + +def _read_pfm_flow(filename): + f, _ = _read_pfm(filename) + assert np.all(f[:,:,2]==0.0) + return np.ascontiguousarray(f[:,:,:2]) + +TAG_FLOAT = 202021.25 # tag to check the sanity of the file +TAG_STRING = 'PIEH' # string containing the tag +MIN_WIDTH = 1 +MAX_WIDTH = 99999 +MIN_HEIGHT = 1 +MAX_HEIGHT = 99999 +def readFlowFile(filename): + """ + readFlowFile() reads a flow file into a 2-band np.array. + if does not exist, an IOError is raised. + if does not finish by '.flo' or the tag, the width, the height or the file's size is illegal, an Expcetion is raised. + ---- PARAMETERS ---- + filename: string containg the name of the file to read a flow + ---- OUTPUTS ---- + a np.array of dimension (height x width x 2) containing the flow of type 'float32' + """ + + # check filename + if not filename.endswith(".flo"): + raise Exception("readFlowFile({:s}): filename must finish with '.flo'".format(filename)) + + # open the file and read it + with open(filename,'rb') as f: + # check tag + tag = struct.unpack('f',f.read(4))[0] + if tag != TAG_FLOAT: + raise Exception("flow_utils.readFlowFile({:s}): wrong tag".format(filename)) + # read dimension + w,h = struct.unpack('ii',f.read(8)) + if w < MIN_WIDTH or w > MAX_WIDTH: + raise Exception("flow_utils.readFlowFile({:s}: illegal width {:d}".format(filename,w)) + if h < MIN_HEIGHT or h > MAX_HEIGHT: + raise Exception("flow_utils.readFlowFile({:s}: illegal height {:d}".format(filename,h)) + flow = np.fromfile(f,'float32') + if not flow.shape == (h*w*2,): + raise Exception("flow_utils.readFlowFile({:s}: illegal size of the file".format(filename)) + flow.shape = (h,w,2) + return flow + +def writeFlowFile(flow,filename): + """ + writeFlowFile(flow,) write flow to the file . + if does not exist, an IOError is raised. + if does not finish with '.flo' or the flow has not 2 bands, an Exception is raised. + ---- PARAMETERS ---- + flow: np.array of dimension (height x width x 2) containing the flow to write + filename: string containg the name of the file to write a flow + """ + + # check filename + if not filename.endswith(".flo"): + raise Exception("flow_utils.writeFlowFile(,{:s}): filename must finish with '.flo'".format(filename)) + + if not flow.shape[2:] == (2,): + raise Exception("flow_utils.writeFlowFile(,{:s}): must have 2 bands".format(filename)) + + + # open the file and write it + with open(filename,'wb') as f: + # write TAG + f.write( TAG_STRING.encode('utf-8') ) + # write dimension + f.write( struct.pack('ii',flow.shape[1],flow.shape[0]) ) + # write the flow + + flow.astype(np.float32).tofile(f) + +_read_flo_file = readFlowFile + +def _read_kitti_flow(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) + valid = flow[:, :, 2]>0 + flow = flow[:, :, :2] + flow = (flow - 2 ** 15) / 64.0 + flow[~valid,0] = np.inf + flow[~valid,1] = np.inf + return flow +_read_hd1k_flow = _read_kitti_flow + + +def writeFlowKitti(filename, uv): + uv = 64.0 * uv + 2 ** 15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + +def writeFlo5File(flow, filename): + with h5py.File(filename, "w") as f: + f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5) + +def _read_hdf5_flow(filename): + flow = np.asarray(h5py.File(filename)['flow']) + flow[np.isnan(flow)] = np.inf # make invalid values as +inf + return flow.astype(np.float32) + +# flow visualization +RY = 15 +YG = 6 +GC = 4 +CB = 11 +BM = 13 +MR = 6 +UNKNOWN_THRESH = 1e9 + +def colorTest(): + """ + flow_utils.colorTest(): display an example of image showing the color encoding scheme + """ + import matplotlib.pylab as plt + truerange = 1 + h,w = 151,151 + trange = truerange*1.04 + s2 = round(h/2) + x,y = np.meshgrid(range(w),range(h)) + u = x*trange/s2-trange + v = y*trange/s2-trange + img = _computeColor(np.concatenate((u[:,:,np.newaxis],v[:,:,np.newaxis]),2)/trange/np.sqrt(2)) + plt.imshow(img) + plt.axis('off') + plt.axhline(round(h/2),color='k') + plt.axvline(round(w/2),color='k') + +def flowToColor(flow, maxflow=None, maxmaxflow=None, saturate=False): + """ + flow_utils.flowToColor(flow): return a color code flow field, normalized based on the maximum l2-norm of the flow + flow_utils.flowToColor(flow,maxflow): return a color code flow field, normalized by maxflow + ---- PARAMETERS ---- + flow: flow to display of shape (height x width x 2) + maxflow (default:None): if given, normalize the flow by its value, otherwise by the flow norm + maxmaxflow (default:None): if given, normalize the flow by the max of its value and the flow norm + ---- OUTPUT ---- + an np.array of shape (height x width x 3) of type uint8 containing a color code of the flow + """ + h,w,n = flow.shape + # check size of flow + assert n == 2, "flow_utils.flowToColor(flow): flow must have 2 bands" + # fix unknown flow + unknown_idx = np.max(np.abs(flow),2)>UNKNOWN_THRESH + flow[unknown_idx] = 0.0 + # compute max flow if needed + if maxflow is None: + maxflow = flowMaxNorm(flow) + if maxmaxflow is not None: + maxflow = min(maxmaxflow, maxflow) + # normalize flow + eps = np.spacing(1) # minimum positive float value to avoid division by 0 + # compute the flow + img = _computeColor(flow/(maxflow+eps), saturate=saturate) + # put black pixels in unknown location + img[ np.tile( unknown_idx[:,:,np.newaxis],[1,1,3]) ] = 0.0 + return img + +def flowMaxNorm(flow): + """ + flow_utils.flowMaxNorm(flow): return the maximum of the l2-norm of the given flow + ---- PARAMETERS ---- + flow: the flow + + ---- OUTPUT ---- + a float containing the maximum of the l2-norm of the flow + """ + return np.max( np.sqrt( np.sum( np.square( flow ) , 2) ) ) + +def _computeColor(flow, saturate=True): + """ + flow_utils._computeColor(flow): compute color codes for the flow field flow + + ---- PARAMETERS ---- + flow: np.array of dimension (height x width x 2) containing the flow to display + ---- OUTPUTS ---- + an np.array of dimension (height x width x 3) containing the color conversion of the flow + """ + # set nan to 0 + nanidx = np.isnan(flow[:,:,0]) + flow[nanidx] = 0.0 + + # colorwheel + ncols = RY + YG + GC + CB + BM + MR + nchans = 3 + colorwheel = np.zeros((ncols,nchans),'uint8') + col = 0; + #RY + colorwheel[:RY,0] = 255 + colorwheel[:RY,1] = [(255*i) // RY for i in range(RY)] + col += RY + # YG + colorwheel[col:col+YG,0] = [255 - (255*i) // YG for i in range(YG)] + colorwheel[col:col+YG,1] = 255 + col += YG + # GC + colorwheel[col:col+GC,1] = 255 + colorwheel[col:col+GC,2] = [(255*i) // GC for i in range(GC)] + col += GC + # CB + colorwheel[col:col+CB,1] = [255 - (255*i) // CB for i in range(CB)] + colorwheel[col:col+CB,2] = 255 + col += CB + # BM + colorwheel[col:col+BM,0] = [(255*i) // BM for i in range(BM)] + colorwheel[col:col+BM,2] = 255 + col += BM + # MR + colorwheel[col:col+MR,0] = 255 + colorwheel[col:col+MR,2] = [255 - (255*i) // MR for i in range(MR)] + + # compute utility variables + rad = np.sqrt( np.sum( np.square(flow) , 2) ) # magnitude + a = np.arctan2( -flow[:,:,1] , -flow[:,:,0]) / np.pi # angle + fk = (a+1)/2 * (ncols-1) # map [-1,1] to [0,ncols-1] + k0 = np.floor(fk).astype('int') + k1 = k0+1 + k1[k1==ncols] = 0 + f = fk-k0 + + if not saturate: + rad = np.minimum(rad,1) + + # compute the image + img = np.zeros( (flow.shape[0],flow.shape[1],nchans), 'uint8' ) + for i in range(nchans): + tmp = colorwheel[:,i].astype('float') + col0 = tmp[k0]/255 + col1 = tmp[k1]/255 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1-rad[idx]*(1-col[idx]) # increase saturation with radius + col[~idx] *= 0.75 # out of range + img[:,:,i] = (255*col*(1-nanidx.astype('float'))).astype('uint8') + + return img + +# flow dataset getter + +def get_train_dataset_flow(dataset_str, augmentor=True, crop_size=None): + dataset_str = dataset_str.replace('(','Dataset(') + if augmentor: + dataset_str = dataset_str.replace(')',', augmentor=True)') + if crop_size is not None: + dataset_str = dataset_str.replace(')',', crop_size={:s})'.format(str(crop_size))) + return eval(dataset_str) + +def get_test_datasets_flow(dataset_str): + dataset_str = dataset_str.replace('(','Dataset(') + return [eval(s) for s in dataset_str.split('+')] \ No newline at end of file diff --git a/monst3r/croco/stereoflow/datasets_stereo.py b/monst3r/croco/stereoflow/datasets_stereo.py new file mode 100644 index 0000000..dbdf841 --- /dev/null +++ b/monst3r/croco/stereoflow/datasets_stereo.py @@ -0,0 +1,674 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Dataset structure for stereo +# -------------------------------------------------------- + +import sys, os +import os.path as osp +import pickle +import numpy as np +from PIL import Image +import json +import h5py +from glob import glob +import cv2 + +import torch +from torch.utils import data + +from .augmentor import StereoAugmentor + + + +dataset_to_root = { + 'CREStereo': './data/stereoflow//crenet_stereo_trainset/stereo_trainset/crestereo/', + 'SceneFlow': './data/stereoflow//SceneFlow/', + 'ETH3DLowRes': './data/stereoflow/eth3d_lowres/', + 'Booster': './data/stereoflow/booster_gt/', + 'Middlebury2021': './data/stereoflow/middlebury/2021/data/', + 'Middlebury2014': './data/stereoflow/middlebury/2014/', + 'Middlebury2006': './data/stereoflow/middlebury/2006/', + 'Middlebury2005': './data/stereoflow/middlebury/2005/train/', + 'MiddleburyEval3': './data/stereoflow/middlebury/MiddEval3/', + 'Spring': './data/stereoflow/spring/', + 'Kitti15': './data/stereoflow/kitti-stereo-2015/', + 'Kitti12': './data/stereoflow/kitti-stereo-2012/', +} +cache_dir = "./data/stereoflow/datasets_stereo_cache/" + + +in1k_mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1) +in1k_std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1) +def img_to_tensor(img): + img = torch.from_numpy(img).permute(2, 0, 1).float() / 255. + img = (img-in1k_mean)/in1k_std + return img +def disp_to_tensor(disp): + return torch.from_numpy(disp)[None,:,:] + +class StereoDataset(data.Dataset): + + def __init__(self, split, augmentor=False, crop_size=None, totensor=True): + self.split = split + if not augmentor: assert crop_size is None + if crop_size: assert augmentor + self.crop_size = crop_size + self.augmentor_str = augmentor + self.augmentor = StereoAugmentor(crop_size) if augmentor else None + self.totensor = totensor + self.rmul = 1 # keep track of rmul + self.has_constant_resolution = True # whether the dataset has constant resolution or not (=> don't use batch_size>1 at test time) + self._prepare_data() + self._load_or_build_cache() + + def prepare_data(self): + """ + to be defined for each dataset + """ + raise NotImplementedError + + def __len__(self): + return len(self.pairnames) + + def __getitem__(self, index): + pairname = self.pairnames[index] + + # get filenames + Limgname = self.pairname_to_Limgname(pairname) + Rimgname = self.pairname_to_Rimgname(pairname) + Ldispname = self.pairname_to_Ldispname(pairname) if self.pairname_to_Ldispname is not None else None + + # load images and disparities + Limg = _read_img(Limgname) + Rimg = _read_img(Rimgname) + disp = self.load_disparity(Ldispname) if Ldispname is not None else None + + # sanity check + if disp is not None: assert np.all(disp>0) or self.name=="Spring", (self.name, pairname, Ldispname) + + # apply augmentations + if self.augmentor is not None: + Limg, Rimg, disp = self.augmentor(Limg, Rimg, disp, self.name) + + if self.totensor: + Limg = img_to_tensor(Limg) + Rimg = img_to_tensor(Rimg) + if disp is None: + disp = torch.tensor([]) # to allow dataloader batching with default collate_gn + else: + disp = disp_to_tensor(disp) + + return Limg, Rimg, disp, str(pairname) + + def __rmul__(self, v): + self.rmul *= v + self.pairnames = v * self.pairnames + return self + + def __str__(self): + return f'{self.__class__.__name__}_{self.split}' + + def __repr__(self): + s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})' + if self.rmul==1: + s+=f'\n\tnum pairs: {len(self.pairnames)}' + else: + s+=f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})' + return s + + def _set_root(self): + self.root = dataset_to_root[self.name] + assert os.path.isdir(self.root), f"could not find root directory for dataset {self.name}: {self.root}" + + def _load_or_build_cache(self): + cache_file = osp.join(cache_dir, self.name+'.pkl') + if osp.isfile(cache_file): + with open(cache_file, 'rb') as fid: + self.pairnames = pickle.load(fid)[self.split] + else: + tosave = self._build_cache() + os.makedirs(cache_dir, exist_ok=True) + with open(cache_file, 'wb') as fid: + pickle.dump(tosave, fid) + self.pairnames = tosave[self.split] + +class CREStereoDataset(StereoDataset): + + def _prepare_data(self): + self.name = 'CREStereo' + self._set_root() + assert self.split in ['train'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_left.jpg') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname+'_right.jpg') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname+'_left.disp.png') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_crestereo_disp + + + def _build_cache(self): + allpairs = [s+'/'+f[:-len('_left.jpg')] for s in sorted(os.listdir(self.root)) for f in sorted(os.listdir(self.root+'/'+s)) if f.endswith('_left.jpg')] + assert len(allpairs)==200000, "incorrect parsing of pairs in CreStereo" + tosave = {'train': allpairs} + return tosave + +class SceneFlowDataset(StereoDataset): + + def _prepare_data(self): + self.name = "SceneFlow" + self._set_root() + assert self.split in ['train_finalpass','train_cleanpass','train_allpass','test_finalpass','test_cleanpass','test_allpass','test1of100_cleanpass','test1of100_finalpass'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/left/','/right/') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname).replace('/frames_finalpass/','/disparity/').replace('/frames_cleanpass/','/disparity/')[:-4]+'.pfm' + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_sceneflow_disp + + def _build_cache(self): + trainpairs = [] + # driving + pairs = sorted(glob(self.root+'Driving/frames_finalpass/*/*/*/left/*.png')) + pairs = list(map(lambda x: x[len(self.root):], pairs)) + assert len(pairs) == 4400, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + # monkaa + pairs = sorted(glob(self.root+'Monkaa/frames_finalpass/*/left/*.png')) + pairs = list(map(lambda x: x[len(self.root):], pairs)) + assert len(pairs) == 8664, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + # flyingthings + pairs = sorted(glob(self.root+'FlyingThings/frames_finalpass/TRAIN/*/*/left/*.png')) + pairs = list(map(lambda x: x[len(self.root):], pairs)) + assert len(pairs) == 22390, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + assert len(trainpairs) == 35454, "incorrect parsing of pairs in SceneFlow" + testpairs = sorted(glob(self.root+'FlyingThings/frames_finalpass/TEST/*/*/left/*.png')) + testpairs = list(map(lambda x: x[len(self.root):], testpairs)) + assert len(testpairs) == 4370, "incorrect parsing of pairs in SceneFlow" + test1of100pairs = testpairs[::100] + assert len(test1of100pairs) == 44, "incorrect parsing of pairs in SceneFlow" + # all + tosave = {'train_finalpass': trainpairs, + 'train_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), trainpairs)), + 'test_finalpass': testpairs, + 'test_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), testpairs)), + 'test1of100_finalpass': test1of100pairs, + 'test1of100_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), test1of100pairs)), + } + tosave['train_allpass'] = tosave['train_finalpass']+tosave['train_cleanpass'] + tosave['test_allpass'] = tosave['test_finalpass']+tosave['test_cleanpass'] + return tosave + +class Md21Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2021" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/im0','/im1')) + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp0.pfm') + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury_disp + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + #trainpairs += [s+'/im0.png'] # we should remove it, it is included as such in other lightings + trainpairs += [s+'/ambient/'+b+'/'+a for b in sorted(os.listdir(osp.join(self.root,s,'ambient'))) for a in sorted(os.listdir(osp.join(self.root,s,'ambient',b))) if a.startswith('im0')] + assert len(trainpairs)==355 + subtrainpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in seqs[:-2])] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in seqs[-2:])] + assert len(subtrainpairs)==335 and len(subvalpairs)==20, "incorrect parsing of pairs in Middlebury 2021" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class Md14Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2014" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'im0.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'disp0.pfm') + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury_disp + self.has_constant_resolution = False + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + trainpairs += [s+'/im1.png',s+'/im1E.png',s+'/im1L.png'] + assert len(trainpairs)==138 + valseqs = ['Umbrella-imperfect','Vintage-perfect'] + assert all(s in seqs for s in valseqs) + subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)] + assert len(subtrainpairs)==132 and len(subvalpairs)==6, "incorrect parsing of pairs in Middlebury 2014" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class Md06Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2006" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png') + self.load_disparity = _read_middlebury20052006_disp + self.has_constant_resolution = False + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + for i in ['Illum1','Illum2','Illum3']: + for e in ['Exp0','Exp1','Exp2']: + trainpairs.append(osp.join(s,i,e,'view1.png')) + assert len(trainpairs)==189 + valseqs = ['Rocks1','Wood2'] + assert all(s in seqs for s in valseqs) + subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)] + assert len(subtrainpairs)==171 and len(subvalpairs)==18, "incorrect parsing of pairs in Middlebury 2006" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class Md05Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2005" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png') + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury20052006_disp + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + for i in ['Illum1','Illum2','Illum3']: + for e in ['Exp0','Exp1','Exp2']: + trainpairs.append(osp.join(s,i,e,'view1.png')) + assert len(trainpairs)==54, "incorrect parsing of pairs in Middlebury 2005" + valseqs = ['Reindeer'] + assert all(s in seqs for s in valseqs) + subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)] + assert len(subtrainpairs)==45 and len(subvalpairs)==9, "incorrect parsing of pairs in Middlebury 2005" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class MdEval3Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "MiddleburyEval3" + self._set_root() + assert self.split in [s+'_'+r for s in ['train','subtrain','subval','test','all'] for r in ['full','half','quarter']] + if self.split.endswith('_full'): + self.root = self.root.replace('/MiddEval3','/MiddEval3_F') + elif self.split.endswith('_half'): + self.root = self.root.replace('/MiddEval3','/MiddEval3_H') + else: + assert self.split.endswith('_quarter') + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png') + self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname, 'disp0GT.pfm') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_middlebury_disp + # for submission only + self.submission_methodname = "CroCo-Stereo" + self.submission_sresolution = 'F' if self.split.endswith('_full') else ('H' if self.split.endswith('_half') else 'Q') + + def _build_cache(self): + trainpairs = ['train/'+s for s in sorted(os.listdir(self.root+'train/'))] + testpairs = ['test/'+s for s in sorted(os.listdir(self.root+'test/'))] + subvalpairs = trainpairs[-1:] + subtrainpairs = trainpairs[:-1] + allpairs = trainpairs+testpairs + assert len(trainpairs)==15 and len(testpairs)==15 and len(subvalpairs)==1 and len(subtrainpairs)==14 and len(allpairs)==30, "incorrect parsing of pairs in Middlebury Eval v3" + tosave = {} + for r in ['full','half','quarter']: + tosave.update(**{'train_'+r: trainpairs, 'subtrain_'+r: subtrainpairs, 'subval_'+r: subvalpairs, 'test_'+r: testpairs, 'all_'+r: allpairs}) + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, pairname.split('/')[0].replace('train','training')+self.submission_sresolution, pairname.split('/')[1], 'disp0'+self.submission_methodname+'.pfm') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writePFM(outfile, prediction) + timefile = os.path.join( os.path.dirname(outfile), "time"+self.submission_methodname+'.txt') + with open(timefile, 'w') as fid: + fid.write(str(time)) + + def finalize_submission(self, outdir): + cmd = f'cd {outdir}/; zip -r "{self.submission_methodname}.zip" .' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/{self.submission_methodname}.zip') + +class ETH3DLowResDataset(StereoDataset): + + def _prepare_data(self): + self.name = "ETH3DLowRes" + self._set_root() + assert self.split in ['train','test','subtrain','subval','all'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png') + self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: None if pairname.startswith('test/') else osp.join(self.root, pairname.replace('train/','train_gt/'), 'disp0GT.pfm') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_eth3d_disp + self.has_constant_resolution = False + + def _build_cache(self): + trainpairs = ['train/' + s for s in sorted(os.listdir(self.root+'train/'))] + testpairs = ['test/' + s for s in sorted(os.listdir(self.root+'test/'))] + assert len(trainpairs) == 27 and len(testpairs) == 20, "incorrect parsing of pairs in ETH3D Low Res" + subvalpairs = ['train/delivery_area_3s','train/electro_3l','train/playground_3l'] + assert all(p in trainpairs for p in subvalpairs) + subtrainpairs = [p for p in trainpairs if not p in subvalpairs] + assert len(subvalpairs)==3 and len(subtrainpairs)==24, "incorrect parsing of pairs in ETH3D Low Res" + tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs, 'all': trainpairs+testpairs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, 'low_res_two_view', pairname.split('/')[1]+'.pfm') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writePFM(outfile, prediction) + timefile = outfile[:-4]+'.txt' + with open(timefile, 'w') as fid: + fid.write('runtime '+str(time)) + + def finalize_submission(self, outdir): + cmd = f'cd {outdir}/; zip -r "eth3d_low_res_two_view_results.zip" low_res_two_view' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/eth3d_low_res_two_view_results.zip') + +class BoosterDataset(StereoDataset): + + def _prepare_data(self): + self.name = "Booster" + self._set_root() + assert self.split in ['train_balanced','test_balanced','subtrain_balanced','subval_balanced'] # we use only the balanced version + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/camera_00/','/camera_02/') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), '../disp_00.npy') # same images with different colors, same gt per sequence + self.pairname_to_str = lambda pairname: pairname[:-4].replace('/camera_00/','/') + self.load_disparity = _read_booster_disp + + + def _build_cache(self): + trainseqs = sorted(os.listdir(self.root+'train/balanced')) + trainpairs = ['train/balanced/'+s+'/camera_00/'+imname for s in trainseqs for imname in sorted(os.listdir(self.root+'train/balanced/'+s+'/camera_00/'))] + testpairs = ['test/balanced/'+s+'/camera_00/'+imname for s in sorted(os.listdir(self.root+'test/balanced')) for imname in sorted(os.listdir(self.root+'test/balanced/'+s+'/camera_00/'))] + assert len(trainpairs) == 228 and len(testpairs) == 191 + subtrainpairs = [p for p in trainpairs if any(s in p for s in trainseqs[:-2])] + subvalpairs = [p for p in trainpairs if any(s in p for s in trainseqs[-2:])] + # warning: if we do validation split, we should split scenes!!! + tosave = {'train_balanced': trainpairs, 'test_balanced': testpairs, 'subtrain_balanced': subtrainpairs, 'subval_balanced': subvalpairs,} + return tosave + +class SpringDataset(StereoDataset): + + def _prepare_data(self): + self.name = "Spring" + self._set_root() + assert self.split in ['train', 'test', 'subtrain', 'subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname+'.png').replace('frame_right','').replace('frame_left','frame_right').replace('','frame_left') + self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_hdf5_disp + + def _build_cache(self): + trainseqs = sorted(os.listdir( osp.join(self.root,'train'))) + trainpairs = [osp.join('train',s,'frame_left',f[:-4]) for s in trainseqs for f in sorted(os.listdir(osp.join(self.root,'train',s,'frame_left')))] + testseqs = sorted(os.listdir( osp.join(self.root,'test'))) + testpairs = [osp.join('test',s,'frame_left',f[:-4]) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,'frame_left')))] + testpairs += [p.replace('frame_left','frame_right') for p in testpairs] + """maxnorm = {'0001': 32.88, '0002': 228.5, '0004': 298.2, '0005': 142.5, '0006': 113.6, '0007': 27.3, '0008': 554.5, '0009': 155.6, '0010': 126.1, '0011': 87.6, '0012': 303.2, '0013': 24.14, '0014': 82.56, '0015': 98.44, '0016': 156.9, '0017': 28.17, '0018': 21.03, '0020': 178.0, '0021': 58.06, '0022': 354.2, '0023': 8.79, '0024': 97.06, '0025': 55.16, '0026': 91.9, '0027': 156.6, '0030': 200.4, '0032': 58.66, '0033': 373.5, '0036': 149.4, '0037': 5.625, '0038': 37.0, '0039': 12.2, '0041': 453.5, '0043': 457.0, '0044': 379.5, '0045': 161.8, '0047': 105.44} # => let'use 0041""" + subtrainpairs = [p for p in trainpairs if p.split('/')[1]!='0041'] + subvalpairs = [p for p in trainpairs if p.split('/')[1]=='0041'] + assert len(trainpairs)==5000 and len(testpairs)==2000 and len(subtrainpairs)==4904 and len(subvalpairs)==96, "incorrect parsing of pairs in Spring" + tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeDsp5File(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + exe = "{self.root}/disp1_subsampling" + if os.path.isfile(exe): + cmd = f'cd "{outdir}/test"; {exe} .' + print(cmd) + os.system(cmd) + else: + print('Could not find disp1_subsampling executable for submission.') + print('Please download it and run:') + print(f'cd "{outdir}/test"; .') + +class Kitti12Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Kitti12" + self._set_root() + assert self.split in ['train','test'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/colored_1/')+'_10.png') + self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/disp_occ/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/colored_0/','/') + self.load_disparity = _read_kitti_disp + + def _build_cache(self): + trainseqs = ["training/colored_0/%06d"%(i) for i in range(194)] + testseqs = ["testing/colored_0/%06d"%(i) for i in range(195)] + assert len(trainseqs)==194 and len(testseqs)==195, "incorrect parsing of pairs in Kitti12" + tosave = {'train': trainseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + img = (prediction * 256).astype('uint16') + Image.fromarray(img).save(outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti12_results.zip" .' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/kitti12_results.zip') + +class Kitti15Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Kitti15" + self._set_root() + assert self.split in ['train','subtrain','subval','test'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/image_3/')+'_10.png') + self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/disp_occ_0/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/image_2/','/') + self.load_disparity = _read_kitti_disp + + def _build_cache(self): + trainseqs = ["training/image_2/%06d"%(i) for i in range(200)] + subtrainseqs = trainseqs[:-5] + subvalseqs = trainseqs[-5:] + testseqs = ["testing/image_2/%06d"%(i) for i in range(200)] + assert len(trainseqs)==200 and len(subtrainseqs)==195 and len(subvalseqs)==5 and len(testseqs)==200, "incorrect parsing of pairs in Kitti15" + tosave = {'train': trainseqs, 'subtrain': subtrainseqs, 'subval': subvalseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, 'disp_0', pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + img = (prediction * 256).astype('uint16') + Image.fromarray(img).save(outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti15_results.zip" disp_0' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/kitti15_results.zip') + + +### auxiliary functions + +def _read_img(filename): + # convert to RGB for scene flow finalpass data + img = np.asarray(Image.open(filename).convert('RGB')) + return img + +def _read_booster_disp(filename): + disp = np.load(filename) + disp[disp==0.0] = np.inf + return disp + +def _read_png_disp(filename, coef=1.0): + disp = np.asarray(Image.open(filename)) + disp = disp.astype(np.float32) / coef + disp[disp==0.0] = np.inf + return disp + +def _read_pfm_disp(filename): + disp = np.ascontiguousarray(_read_pfm(filename)[0]) + disp[disp<=0] = np.inf # eg /nfs/data/ffs-3d/datasets/middlebury/2014/Shopvac-imperfect/disp0.pfm + return disp + +def _read_npy_disp(filename): + return np.load(filename) + +def _read_crestereo_disp(filename): return _read_png_disp(filename, coef=32.0) +def _read_middlebury20052006_disp(filename): return _read_png_disp(filename, coef=1.0) +def _read_kitti_disp(filename): return _read_png_disp(filename, coef=256.0) +_read_sceneflow_disp = _read_pfm_disp +_read_eth3d_disp = _read_pfm_disp +_read_middlebury_disp = _read_pfm_disp +_read_carla_disp = _read_pfm_disp +_read_tartanair_disp = _read_npy_disp + +def _read_hdf5_disp(filename): + disp = np.asarray(h5py.File(filename)['disparity']) + disp[np.isnan(disp)] = np.inf # make invalid values as +inf + #disp[disp==0.0] = np.inf # make invalid values as +inf + return disp.astype(np.float32) + +import re +def _read_pfm(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == 'PF': + color = True + elif header.decode("ascii") == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + +def writePFM(file, image, scale=1): + file = open(file, 'wb') + + color = None + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale + color = False + else: + raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write('PF\n' if color else 'Pf\n'.encode()) + file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + file.write('%f\n'.encode() % scale) + + image.tofile(file) + +def writeDsp5File(disp, filename): + with h5py.File(filename, "w") as f: + f.create_dataset("disparity", data=disp, compression="gzip", compression_opts=5) + + +# disp visualization + +def vis_disparity(disp, m=None, M=None): + if m is None: m = disp.min() + if M is None: M = disp.max() + disp_vis = (disp - m) / (M-m) * 255.0 + disp_vis = disp_vis.astype("uint8") + disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) + return disp_vis + +# dataset getter + +def get_train_dataset_stereo(dataset_str, augmentor=True, crop_size=None): + dataset_str = dataset_str.replace('(','Dataset(') + if augmentor: + dataset_str = dataset_str.replace(')',', augmentor=True)') + if crop_size is not None: + dataset_str = dataset_str.replace(')',', crop_size={:s})'.format(str(crop_size))) + return eval(dataset_str) + +def get_test_datasets_stereo(dataset_str): + dataset_str = dataset_str.replace('(','Dataset(') + return [eval(s) for s in dataset_str.split('+')] \ No newline at end of file diff --git a/monst3r/croco/stereoflow/download_model.sh b/monst3r/croco/stereoflow/download_model.sh new file mode 100644 index 0000000..5331196 --- /dev/null +++ b/monst3r/croco/stereoflow/download_model.sh @@ -0,0 +1,12 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +model=$1 +outfile="stereoflow_models/${model}" +if [[ ! -f $outfile ]] +then + mkdir -p stereoflow_models/; + wget https://download.europe.naverlabs.com/ComputerVision/CroCo/StereoFlow_models/$1 -P stereoflow_models/; +else + echo "Model ${model} already downloaded in ${outfile}." +fi \ No newline at end of file diff --git a/monst3r/croco/stereoflow/engine.py b/monst3r/croco/stereoflow/engine.py new file mode 100644 index 0000000..c057346 --- /dev/null +++ b/monst3r/croco/stereoflow/engine.py @@ -0,0 +1,280 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Main function for training one epoch or testing +# -------------------------------------------------------- + +import math +import sys +from typing import Iterable +import numpy as np +import torch +import torchvision + +from utils import misc as misc + + +def split_prediction_conf(predictions, with_conf=False): + if not with_conf: + return predictions, None + conf = predictions[:,-1:,:,:] + predictions = predictions[:,:-1,:,:] + return predictions, conf + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, metrics: torch.nn.Module, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, + log_writer=None, print_freq = 20, + args=None): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + + accum_iter = args.accum_iter + + optimizer.zero_grad() + + details = {} + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + if args.img_per_epoch: + iter_per_epoch = args.img_per_epoch // args.batch_size + int(args.img_per_epoch % args.batch_size > 0) + assert len(data_loader) >= iter_per_epoch, 'Dataset is too small for so many iterations' + len_data_loader = iter_per_epoch + else: + len_data_loader, iter_per_epoch = len(data_loader), None + + for data_iter_step, (image1, image2, gt, pairname) in enumerate(metric_logger.log_every(data_loader, print_freq, header, max_iter=iter_per_epoch)): + + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = gt.to(device, non_blocking=True) + + # we use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate(optimizer, data_iter_step / len_data_loader + epoch, args) + + with torch.cuda.amp.autocast(enabled=bool(args.amp)): + prediction = model(image1, image2) + prediction, conf = split_prediction_conf(prediction, criterion.with_conf) + batch_metrics = metrics(prediction.detach(), gt) + loss = criterion(prediction, gt) if conf is None else criterion(prediction, gt, conf) + + loss_value = loss.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) + for k,v in batch_metrics.items(): + metric_logger.update(**{k: v.item()}) + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(lr=lr) + + #if args.dsitributed: loss_value_reduce = misc.all_reduce_mean(loss_value) + time_to_log = ((data_iter_step + 1) % (args.tboard_log_step * accum_iter) == 0 or data_iter_step == len_data_loader-1) + loss_value_reduce = misc.all_reduce_mean(loss_value) + if log_writer is not None and time_to_log: + epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000) + # We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes. + log_writer.add_scalar('train/loss', loss_value_reduce, epoch_1000x) + log_writer.add_scalar('lr', lr, epoch_1000x) + for k,v in batch_metrics.items(): + log_writer.add_scalar('train/'+k, v.item(), epoch_1000x) + + # gather the stats from all processes + #if args.distributed: metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def validate_one_epoch(model: torch.nn.Module, + criterion: torch.nn.Module, + metrics: torch.nn.Module, + data_loaders: list[Iterable], + device: torch.device, + epoch: int, + log_writer=None, + args=None): + + model.eval() + metric_loggers = [] + header = 'Epoch: [{}]'.format(epoch) + print_freq = 20 + + conf_mode = args.tile_conf_mode + crop = args.crop + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + results = {} + dnames = [] + image1, image2, gt, prediction = None, None, None, None + for didx, data_loader in enumerate(data_loaders): + dname = str(data_loader.dataset) + dnames.append(dname) + metric_loggers.append(misc.MetricLogger(delimiter=" ")) + for data_iter_step, (image1, image2, gt, pairname) in enumerate(metric_loggers[didx].log_every(data_loader, print_freq, header)): + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = gt.to(device, non_blocking=True) + if dname.startswith('Spring'): + assert gt.size(2)==image1.size(2)*2 and gt.size(3)==image1.size(3)*2 + gt = (gt[:,:,0::2,0::2] + gt[:,:,0::2,1::2] + gt[:,:,1::2,0::2] + gt[:,:,1::2,1::2] ) / 4.0 # we approximate the gt based on the 2x upsampled ones + + with torch.inference_mode(): + prediction, tiled_loss, c = tiled_pred(model, criterion, image1, image2, gt, conf_mode=conf_mode, overlap=args.val_overlap, crop=crop, with_conf=criterion.with_conf) + batch_metrics = metrics(prediction.detach(), gt) + loss = criterion(prediction.detach(), gt) if not criterion.with_conf else criterion(prediction.detach(), gt, c) + loss_value = loss.item() + metric_loggers[didx].update(loss_tiled=tiled_loss.item()) + metric_loggers[didx].update(**{f'loss': loss_value}) + for k,v in batch_metrics.items(): + metric_loggers[didx].update(**{dname+'_' + k: v.item()}) + + results = {k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items()} + if len(dnames)>1: + for k in batch_metrics.keys(): + results['AVG_'+k] = sum(results[dname+'_'+k] for dname in dnames) / len(dnames) + + if log_writer is not None : + epoch_1000x = int((1 + epoch) * 1000) + for k,v in results.items(): + log_writer.add_scalar('val/'+k, v, epoch_1000x) + + print("Averaged stats:", results) + return results + +import torch.nn.functional as F +def _resize_img(img, new_size): + return F.interpolate(img, size=new_size, mode='bicubic', align_corners=False) +def _resize_stereo_or_flow(data, new_size): + assert data.ndim==4 + assert data.size(1) in [1,2] + scale_x = new_size[1]/float(data.size(3)) + out = F.interpolate(data, size=new_size, mode='bicubic', align_corners=False) + out[:,0,:,:] *= scale_x + if out.size(1)==2: + scale_y = new_size[0]/float(data.size(2)) + out[:,1,:,:] *= scale_y + print(scale_x, new_size, data.shape) + return out + + +@torch.no_grad() +def tiled_pred(model, criterion, img1, img2, gt, + overlap=0.5, bad_crop_thr=0.05, + downscale=False, crop=512, ret='loss', + conf_mode='conf_expsigmoid_10_5', with_conf=False, + return_time=False): + + # for each image, we are going to run inference on many overlapping patches + # then, all predictions will be weighted-averaged + if gt is not None: + B, C, H, W = gt.shape + else: + B, _, H, W = img1.shape + C = model.head.num_channels-int(with_conf) + win_height, win_width = crop[0], crop[1] + + # upscale to be larger than the crop + do_change_scale = H= window and 0 <= overlap < 1, (total, window, overlap) + num_windows = 1 + int(np.ceil( (total - window) / ((1-overlap) * window) )) + offsets = np.linspace(0, total-window, num_windows).round().astype(int) + yield from (slice(x, x+window) for x in offsets) + +def _crop(img, sy, sx): + B, THREE, H, W = img.shape + if 0 <= sy.start and sy.stop <= H and 0 <= sx.start and sx.stop <= W: + return img[:,:,sy,sx] + l, r = max(0,-sx.start), max(0,sx.stop-W) + t, b = max(0,-sy.start), max(0,sy.stop-H) + img = torch.nn.functional.pad(img, (l,r,t,b), mode='constant') + return img[:, :, slice(sy.start+t,sy.stop+t), slice(sx.start+l,sx.stop+l)] \ No newline at end of file diff --git a/monst3r/croco/stereoflow/test.py b/monst3r/croco/stereoflow/test.py new file mode 100644 index 0000000..0248e56 --- /dev/null +++ b/monst3r/croco/stereoflow/test.py @@ -0,0 +1,216 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Main test function +# -------------------------------------------------------- + +import os +import argparse +import pickle +from PIL import Image +import numpy as np +from tqdm import tqdm + +import torch +from torch.utils.data import DataLoader + +import utils.misc as misc +from models.croco_downstream import CroCoDownstreamBinocular +from models.head_downstream import PixelwiseTaskWithDPT + +from stereoflow.criterion import * +from stereoflow.datasets_stereo import get_test_datasets_stereo +from stereoflow.datasets_flow import get_test_datasets_flow +from stereoflow.engine import tiled_pred + +from stereoflow.datasets_stereo import vis_disparity +from stereoflow.datasets_flow import flowToColor + +def get_args_parser(): + parser = argparse.ArgumentParser('Test CroCo models on stereo/flow', add_help=False) + # important argument + parser.add_argument('--model', required=True, type=str, help='Path to the model to evaluate') + parser.add_argument('--dataset', required=True, type=str, help="test dataset (there can be multiple dataset separated by a +)") + # tiling + parser.add_argument('--tile_conf_mode', type=str, default='', help='Weights for the tiling aggregation based on confidence (empty means use the formula from the loaded checkpoint') + parser.add_argument('--tile_overlap', type=float, default=0.7, help='overlap between tiles') + # save (it will automatically go to _/_) + parser.add_argument('--save', type=str, nargs='+', default=[], + help='what to save: \ + metrics (pickle file), \ + pred (raw prediction save as torch tensor), \ + visu (visualization in png of each prediction), \ + err10 (visualization in png of the error clamp at 10 for each prediction), \ + submission (submission file)') + # other (no impact) + parser.add_argument('--num_workers', default=4, type=int) + return parser + + +def _load_model_and_criterion(model_path, do_load_metrics, device): + print('loading model from', model_path) + assert os.path.isfile(model_path) + ckpt = torch.load(model_path, 'cpu') + + ckpt_args = ckpt['args'] + task = ckpt_args.task + tile_conf_mode = ckpt_args.tile_conf_mode + num_channels = {'stereo': 1, 'flow': 2}[task] + with_conf = eval(ckpt_args.criterion).with_conf + if with_conf: num_channels += 1 + print('head: PixelwiseTaskWithDPT()') + head = PixelwiseTaskWithDPT() + head.num_channels = num_channels + print('croco_args:', ckpt_args.croco_args) + model = CroCoDownstreamBinocular(head, **ckpt_args.croco_args) + msg = model.load_state_dict(ckpt['model'], strict=True) + model.eval() + model = model.to(device) + + if do_load_metrics: + if task=='stereo': + metrics = StereoDatasetMetrics().to(device) + else: + metrics = FlowDatasetMetrics().to(device) + else: + metrics = None + + return model, metrics, ckpt_args.crop, with_conf, task, tile_conf_mode + + +def _save_batch(pred, gt, pairnames, dataset, task, save, outdir, time, submission_dir=None): + + for i in range(len(pairnames)): + + pairname = eval(pairnames[i]) if pairnames[i].startswith('(') else pairnames[i] # unbatch pairname + fname = os.path.join(outdir, dataset.pairname_to_str(pairname)) + os.makedirs(os.path.dirname(fname), exist_ok=True) + + predi = pred[i,...] + if gt is not None: gti = gt[i,...] + + if 'pred' in save: + torch.save(predi.squeeze(0).cpu(), fname+'_pred.pth') + + if 'visu' in save: + if task=='stereo': + disparity = predi.permute((1,2,0)).squeeze(2).cpu().numpy() + m,M = None + if gt is not None: + mask = torch.isfinite(gti) + m = gt[mask].min() + M = gt[mask].max() + img_disparity = vis_disparity(disparity, m=m, M=M) + Image.fromarray(img_disparity).save(fname+'_pred.png') + else: + # normalize flowToColor according to the maxnorm of gt (or prediction if not available) + flowNorm = torch.sqrt(torch.sum( (gti if gt is not None else predi)**2, dim=0)).max().item() + imgflow = flowToColor(predi.permute((1,2,0)).cpu().numpy(), maxflow=flowNorm) + Image.fromarray(imgflow).save(fname+'_pred.png') + + if 'err10' in save: + assert gt is not None + L2err = torch.sqrt(torch.sum( (gti-predi)**2, dim=0)) + valid = torch.isfinite(gti[0,:,:]) + L2err[~valid] = 0.0 + L2err = torch.clamp(L2err, max=10.0) + red = (L2err*255.0/10.0).to(dtype=torch.uint8)[:,:,None] + zer = torch.zeros_like(red) + imgerr = torch.cat( (red,zer,zer), dim=2).cpu().numpy() + Image.fromarray(imgerr).save(fname+'_err10.png') + + if 'submission' in save: + assert submission_dir is not None + predi_np = predi.permute(1,2,0).squeeze(2).cpu().numpy() # transform into HxWx2 for flow or HxW for stereo + dataset.submission_save_pairname(pairname, predi_np, submission_dir, time) + +def main(args): + + # load the pretrained model and metrics + device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + model, metrics, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion(args.model, 'metrics' in args.save, device) + if args.tile_conf_mode=='': args.tile_conf_mode = tile_conf_mode + + # load the datasets + datasets = (get_test_datasets_stereo if task=='stereo' else get_test_datasets_flow)(args.dataset) + dataloaders = [DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False) for dataset in datasets] + + # run + for i,dataloader in enumerate(dataloaders): + dataset = datasets[i] + dstr = args.dataset.split('+')[i] + + outdir = args.model+'_'+misc.filename(dstr) + if 'metrics' in args.save and len(args.save)==1: + fname = os.path.join(outdir, f'conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}.pkl') + if os.path.isfile(fname) and len(args.save)==1: + print(' metrics already compute in '+fname) + with open(fname, 'rb') as fid: + results = pickle.load(fid) + for k,v in results.items(): + print('{:s}: {:.3f}'.format(k, v)) + continue + + if 'submission' in args.save: + dirname = f'submission_conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}' + submission_dir = os.path.join(outdir, dirname) + else: + submission_dir = None + + print('') + print('saving {:s} in {:s}'.format('+'.join(args.save), outdir)) + print(repr(dataset)) + + if metrics is not None: + metrics.reset() + + for data_iter_step, (image1, image2, gt, pairnames) in enumerate(tqdm(dataloader)): + + do_flip = (task=='stereo' and dstr.startswith('Spring') and any("right" in p for p in pairnames)) # we flip the images and will flip the prediction after as we assume img1 is on the left + + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = gt.to(device, non_blocking=True) if gt.numel()>0 else None # special case for test time + if do_flip: + assert all("right" in p for p in pairnames) + image1 = image1.flip(dims=[3]) # this is already the right frame, let's flip it + image2 = image2.flip(dims=[3]) + gt = gt # that is ok + + with torch.inference_mode(): + pred, _, _, time = tiled_pred(model, None, image1, image2, None if dataset.name=='Spring' else gt, conf_mode=args.tile_conf_mode, overlap=args.tile_overlap, crop=cropsize, with_conf=with_conf, return_time=True) + + if do_flip: + pred = pred.flip(dims=[3]) + + if metrics is not None: + metrics.add_batch(pred, gt) + + if any(k in args.save for k in ['pred','visu','err10','submission']): + _save_batch(pred, gt, pairnames, dataset, task, args.save, outdir, time, submission_dir=submission_dir) + + + # print + if metrics is not None: + results = metrics.get_results() + for k,v in results.items(): + print('{:s}: {:.3f}'.format(k, v)) + + # save if needed + if 'metrics' in args.save: + os.makedirs(os.path.dirname(fname), exist_ok=True) + with open(fname, 'wb') as fid: + pickle.dump(results, fid) + print('metrics saved in', fname) + + # finalize submission if needed + if 'submission' in args.save: + dataset.finalize_submission(submission_dir) + + + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + main(args) \ No newline at end of file diff --git a/monst3r/croco/stereoflow/train.py b/monst3r/croco/stereoflow/train.py new file mode 100644 index 0000000..91f2414 --- /dev/null +++ b/monst3r/croco/stereoflow/train.py @@ -0,0 +1,253 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Main training function +# -------------------------------------------------------- + +import argparse +import datetime +import json +import numpy as np +import os +import sys +import time + +import torch +import torch.distributed as dist +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter +import torchvision.transforms as transforms +import torchvision.datasets as datasets +from torch.utils.data import DataLoader + +import utils +import utils.misc as misc +from utils.misc import NativeScalerWithGradNormCount as NativeScaler +from models.croco_downstream import CroCoDownstreamBinocular, croco_args_from_ckpt +from models.pos_embed import interpolate_pos_embed +from models.head_downstream import PixelwiseTaskWithDPT + +from stereoflow.datasets_stereo import get_train_dataset_stereo, get_test_datasets_stereo +from stereoflow.datasets_flow import get_train_dataset_flow, get_test_datasets_flow +from stereoflow.engine import train_one_epoch, validate_one_epoch +from stereoflow.criterion import * + + +def get_args_parser(): + # prepare subparsers + parser = argparse.ArgumentParser('Finetuning CroCo models on stereo or flow', add_help=False) + subparsers = parser.add_subparsers(title="Task (stereo or flow)", dest="task", required=True) + parser_stereo = subparsers.add_parser('stereo', help='Training stereo model') + parser_flow = subparsers.add_parser('flow', help='Training flow model') + def add_arg(name_or_flags, default=None, default_stereo=None, default_flow=None, **kwargs): + if default is not None: assert default_stereo is None and default_flow is None, "setting default makes default_stereo and default_flow disabled" + parser_stereo.add_argument(name_or_flags, default=default if default is not None else default_stereo, **kwargs) + parser_flow.add_argument(name_or_flags, default=default if default is not None else default_flow, **kwargs) + # output dir + add_arg('--output_dir', required=True, type=str, help='path where to save, if empty, automatically created') + # model + add_arg('--crop', type=int, nargs = '+', default_stereo=[352, 704], default_flow=[320, 384], help = "size of the random image crops used during training.") + add_arg('--pretrained', required=True, type=str, help="Load pretrained model (required as croco arguments come from there)") + # criterion + add_arg('--criterion', default_stereo='LaplacianLossBounded2()', default_flow='LaplacianLossBounded()', type=str, help='string to evaluate to get criterion') + add_arg('--bestmetric', default_stereo='avgerr', default_flow='EPE', type=str) + # dataset + add_arg('--dataset', type=str, required=True, help="training set") + # training + add_arg('--seed', default=0, type=int, help='seed') + add_arg('--batch_size', default_stereo=6, default_flow=8, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') + add_arg('--epochs', default=32, type=int, help='number of training epochs') + add_arg('--img_per_epoch', type=int, default=None, help='Fix the number of images seen in an epoch (None means use all training pairs)') + add_arg('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') + add_arg('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)') + add_arg('--lr', type=float, default_stereo=3e-5, default_flow=2e-5, metavar='LR', help='learning rate (absolute lr)') + add_arg('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0') + add_arg('--warmup_epochs', type=int, default=1, metavar='N', help='epochs to warmup LR') + add_arg('--optimizer', default='AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))', type=str, + help="Optimizer from torch.optim [ default: AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) ]") + add_arg('--amp', default=0, type=int, choices=[0,1], help='enable automatic mixed precision training') + # validation + add_arg('--val_dataset', type=str, default='', help="Validation sets, multiple separated by + (empty string means that no validation is performed)") + add_arg('--tile_conf_mode', type=str, default_stereo='conf_expsigmoid_15_3', default_flow='conf_expsigmoid_10_5', help='Weights for tile aggregation') + add_arg('--val_overlap', default=0.7, type=float, help='Overlap value for the tiling') + # others + add_arg('--num_workers', default=8, type=int) + add_arg('--eval_every', type=int, default=1, help='Val loss evaluation frequency') + add_arg('--save_every', type=int, default=1, help='Save checkpoint frequency') + add_arg('--start_from', type=str, default=None, help='Start training using weights from an other model (eg for finetuning)') + add_arg('--tboard_log_step', type=int, default=100, help='Log to tboard every so many steps') + add_arg('--dist_url', default='env://', help='url used to set up distributed training') + + return parser + + +def main(args): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + num_tasks = misc.get_world_size() + + assert os.path.isfile(args.pretrained) + print("output_dir: "+args.output_dir) + os.makedirs(args.output_dir, exist_ok=True) + + # fix the seed for reproducibility + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + cudnn.benchmark = True + + # Metrics / criterion + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + metrics = (StereoMetrics if args.task=='stereo' else FlowMetrics)().to(device) + criterion = eval(args.criterion).to(device) + print('Criterion: ', args.criterion) + + # Prepare model + assert os.path.isfile(args.pretrained) + ckpt = torch.load(args.pretrained, 'cpu') + croco_args = croco_args_from_ckpt(ckpt) + croco_args['img_size'] = (args.crop[0], args.crop[1]) + print('Croco args: '+str(croco_args)) + args.croco_args = croco_args # saved for test time + # prepare head + num_channels = {'stereo': 1, 'flow': 2}[args.task] + if criterion.with_conf: num_channels += 1 + print(f'Building head PixelwiseTaskWithDPT() with {num_channels} channel(s)') + head = PixelwiseTaskWithDPT() + head.num_channels = num_channels + # build model and load pretrained weights + model = CroCoDownstreamBinocular(head, **croco_args) + interpolate_pos_embed(model, ckpt['model']) + msg = model.load_state_dict(ckpt['model'], strict=False) + print(msg) + + total_params = sum(p.numel() for p in model.parameters()) + total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Total params: {total_params}") + print(f"Total params trainable: {total_params_trainable}") + model_without_ddp = model.to(device) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + print("lr: %.2e" % args.lr) + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], static_graph=True) + model_without_ddp = model.module + + # following timm: set wd as 0 for bias and norm layers + param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) + optimizer = eval(f"torch.optim.{args.optimizer}") + print(optimizer) + loss_scaler = NativeScaler() + + # automatic restart + last_ckpt_fname = os.path.join(args.output_dir, f'checkpoint-last.pth') + args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None + + if not args.resume and args.start_from: + print(f"Starting from an other model's weights: {args.start_from}") + best_so_far = None + args.start_epoch = 0 + ckpt = torch.load(args.start_from, 'cpu') + msg = model_without_ddp.load_state_dict(ckpt['model'], strict=False) + print(msg) + else: + best_so_far = misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) + + if best_so_far is None: best_so_far = np.inf + + # tensorboard + log_writer = None + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir, purge_step=args.start_epoch*1000) + + # dataset and loader + print('Building Train Data loader for dataset: ', args.dataset) + train_dataset = (get_train_dataset_stereo if args.task=='stereo' else get_train_dataset_flow)(args.dataset, crop_size=args.crop) + def _print_repr_dataset(d): + if isinstance(d, torch.utils.data.dataset.ConcatDataset): + for dd in d.datasets: + _print_repr_dataset(dd) + else: + print(repr(d)) + _print_repr_dataset(train_dataset) + print(' total length:', len(train_dataset)) + if args.distributed: + sampler_train = torch.utils.data.DistributedSampler( + train_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + else: + sampler_train = torch.utils.data.RandomSampler(train_dataset) + data_loader_train = torch.utils.data.DataLoader( + train_dataset, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + if args.val_dataset=='': + data_loaders_val = None + else: + print('Building Val Data loader for datasets: ', args.val_dataset) + val_datasets = (get_test_datasets_stereo if args.task=='stereo' else get_test_datasets_flow)(args.val_dataset) + for val_dataset in val_datasets: print(repr(val_dataset)) + data_loaders_val = [DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False) for val_dataset in val_datasets] + bestmetric = ("AVG_" if len(data_loaders_val)>1 else str(data_loaders_val[0].dataset)+'_')+args.bestmetric + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + # Training Loop + for epoch in range(args.start_epoch, args.epochs): + + if args.distributed: data_loader_train.sampler.set_epoch(epoch) + + # Train + epoch_start = time.time() + train_stats = train_one_epoch(model, criterion, metrics, data_loader_train, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args) + epoch_time = time.time() - epoch_start + + if args.distributed: dist.barrier() + + # Validation (current naive implementation runs the validation on every gpu ... not smart ...) + if data_loaders_val is not None and args.eval_every > 0 and (epoch+1) % args.eval_every == 0: + val_epoch_start = time.time() + val_stats = validate_one_epoch(model, criterion, metrics, data_loaders_val, device, epoch, log_writer=log_writer, args=args) + val_epoch_time = time.time() - val_epoch_start + + val_best = val_stats[bestmetric] + + # Save best of all + if val_best <= best_so_far: + best_so_far = val_best + misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, best_so_far=best_so_far, fname='best') + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch, + **{f'val_{k}': v for k, v in val_stats.items()}} + else: + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch,} + + if args.distributed: dist.barrier() + + # Save stuff + if args.output_dir and ((epoch+1) % args.save_every == 0 or epoch + 1 == args.epochs): + misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, best_so_far=best_so_far, fname='last') + + if args.output_dir: + if log_writer is not None: + log_writer.flush() + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + main(args) \ No newline at end of file diff --git a/monst3r/croco/utils/misc.py b/monst3r/croco/utils/misc.py new file mode 100644 index 0000000..a7460c7 --- /dev/null +++ b/monst3r/croco/utils/misc.py @@ -0,0 +1,470 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +import math +import json +from collections import defaultdict, deque +from pathlib import Path +import numpy as np + +import torch +import torch.distributed as dist +from torch import inf + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None, max_iter=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable) + space_fmt = ':' + str(len(str(len_iterable))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for it,obj in enumerate(iterable): + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len_iterable - 1: + eta_seconds = iter_time.global_avg * (len_iterable - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len_iterable, eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len_iterable, eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + if max_iter and it >= max_iter: + break + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len_iterable)) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + nodist = args.nodist if hasattr(args,'nodist') else False + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ and not nodist: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank, timeout=datetime.timedelta(seconds=3600)) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self, enabled=True): + self._scaler = torch.cuda.amp.GradScaler(enabled=enabled) + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + + + +def save_model(args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None, best_pose_ate_sofar=None): + output_dir = Path(args.output_dir) + if fname is None: fname = str(epoch) + checkpoint_path = output_dir / ('checkpoint-%s.pth' % fname) + to_save = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scaler': loss_scaler.state_dict(), + 'args': args, + 'epoch': epoch, + } + if best_so_far is not None: to_save['best_so_far'] = best_so_far + if best_pose_ate_sofar is not None: to_save['best_pose_ate_sofar'] = best_pose_ate_sofar + print(f'>> Saving model to {checkpoint_path} ...') + save_on_master(to_save, checkpoint_path) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + args.start_epoch = 0 + best_so_far = None + best_pose_ate_sofar = None + if args.resume is not None: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + print("Resume checkpoint %s" % args.resume) + model_without_ddp.load_state_dict(checkpoint['model'], strict=False) + args.start_epoch = checkpoint['epoch'] + 1 + optimizer.load_state_dict(checkpoint['optimizer']) + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + if 'best_so_far' in checkpoint: + best_so_far = checkpoint['best_so_far'] + print(" & best_so_far={:g}".format(best_so_far)) + else: + print("") + if 'best_pose_ate_sofar' in checkpoint: + best_pose_ate_sofar = checkpoint['best_pose_ate_sofar'] + print(" & best_pose_ate_sofar={:g}".format(best_pose_ate_sofar)) + else: + best_pose_ate_sofar = None + print("With optim & sched! start_epoch={:d}".format(args.start_epoch), end='') + return best_so_far, best_pose_ate_sofar + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + +def _replace(text, src, tgt, rm=''): + """ Advanced string replacement. + Given a text: + - replace all elements in src by the corresponding element in tgt + - remove all elements in rm + """ + if len(tgt) == 1: + tgt = tgt * len(src) + assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len" + for s,t in zip(src, tgt): + text = text.replace(s,t) + for c in rm: + text = text.replace(c,'') + return text + +def filename( obj ): + """ transform a python obj or cmd into a proper filename. + - \1 gets replaced by slash '/' + - \2 gets replaced by comma ',' + """ + if not isinstance(obj, str): + obj = repr(obj) + obj = str(obj).replace('()','') + obj = _replace(obj, '_,(*/\1\2','-__x%/,', rm=' )\'"') + assert all(len(s) < 256 for s in obj.split(os.sep)), 'filename too long (>256 characters):\n'+obj + return obj + +def _get_num_layer_for_vit(var_name, enc_depth, dec_depth): + if var_name in ("cls_token", "mask_token", "pos_embed", "global_tokens"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("enc_blocks"): + layer_id = int(var_name.split('.')[1]) + return layer_id + 1 + elif var_name.startswith('decoder_embed') or var_name.startswith('enc_norm'): # part of the last black + return enc_depth + elif var_name.startswith('dec_blocks'): + layer_id = int(var_name.split('.')[1]) + return enc_depth + layer_id + 1 + elif var_name.startswith('dec_norm'): # part of the last block + return enc_depth + dec_depth + elif any(var_name.startswith(k) for k in ['head','prediction_head']): + return enc_depth + dec_depth + 1 + else: + raise NotImplementedError(var_name) + +def get_parameter_groups(model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[]): + parameter_group_names = {} + parameter_group_vars = {} + enc_depth, dec_depth = None, None + # prepare layer decay values + assert layer_decay==1.0 or 0. [!NOTE] +> The scripts provided here are for reference only. Please ensure you have obtained the necessary licenses from the original dataset providers before proceeding. + + +## Download Datasets + +### Sintel +To download and prepare the **Sintel** dataset, execute: +```bash +cd data +bash download_sintel.sh +cd .. + +# (optional) generate the GT dynamic mask +cd .. +python datasets_preprocess/sintel_get_dynamics.py --threshold 0.1 --save_dir dynamic_label_perfect +``` + +### Bonn +To download and prepare the **Bonn** dataset, execute: +```bash +cd data +bash download_bonn.sh +cd .. + +# create the subset for video depth evaluation, following depthcrafter +cd datasets_preprocess +python prepare_bonn.py +cd .. +``` + +### KITTI +To download and prepare the **KITTI** dataset, execute: +```bash +cd data +bash download_kitti.sh +cd .. + +# create the subset for video depth evaluation, following depthcrafter +cd datasets_preprocess +python prepare_kitti.py +cd .. +``` + +### NYU-v2 +To download and prepare the **NYU-v2** dataset, execute: +```bash +cd data +bash download_nyuv2.sh +cd .. + +# prepare the dataset for depth evaluation +cd datasets_preprocess +python prepare_nyuv2.py +cd .. +``` + +### TUM-dynamics +To download and prepare the **TUM-dynamics** dataset, execute: +```bash +cd data +bash download_tum.sh +cd .. + +# prepare the dataset for pose evaluation +cd datasets_preprocess +python prepare_tum.py +cd .. +``` + +### ScanNet +To download and prepare the **ScanNet** dataset, execute: +```bash +cd data +bash download_scannetv2.sh +cd .. + +# prepare the dataset for pose evaluation +cd datasets_preprocess +python prepare_scannet.py +cd .. +``` + +### DAVIS +To download and prepare the **DAVIS** dataset, execute: +```bash +cd data +python download_davis.py +cd .. +``` + +## Evaluation Script (Video Depth) + +### Sintel + +```bash +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=29604 launch.py --mode=eval_pose \ + --pretrained="checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth" \ + --eval_dataset=sintel --output_dir="results/sintel_video_depth" --full_seq --no_crop +``` + +The results will be saved in the `results/sintel_video_depth` folder. You could then run the corresponding code block in [depth_metric.ipynb](../depth_metric.ipynb) to evaluate the results. + +### Bonn + +```bash +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=29604 launch.py --mode=eval_pose \ + --pretrained="checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth" \ + --eval_dataset=bonn --output_dir="results/bonn_video_depth" --no_crop +``` + +The results will be saved in the `results/bonn_video_depth` folder. You could then run the corresponding code block in [depth_metric.ipynb](../depth_metric.ipynb) to evaluate the results. + +### KITTI + +```bash +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=29604 launch.py --mode=eval_pose \ + --pretrained="checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth" \ + --eval_dataset=kitti --output_dir="results/kitti_video_depth" --no_crop --flow_loss_weight 0 --translation_weight 1e-3 +# adjust flow loss weight and translation weight due to poor flow prediction result and large translation in KITTI +# updated hyperparameters should give better results of Abs_Rel = 0.089; δ<1.25 = 91.11 +``` + +The results will be saved in the `results/kitti_video_depth` folder. You could then run the corresponding code block in [depth_metric.ipynb](../depth_metric.ipynb) to evaluate the results. + +## Evaluation Script (Camera Pose) + +### Sintel + +```bash +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=29604 launch.py --mode=eval_pose \ + --pretrained="checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth" \ + --eval_dataset=sintel --output_dir="results/sintel_pose" + # To use the ground truth dynamic mask, add: --use_gt_mask +``` + +The evaluation results will be saved in `results/sintel_pose/_error_log.txt`. + +### TUM-dynamics + +```bash +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=29604 launch.py --mode=eval_pose \ + --pretrained="checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth" \ + --eval_dataset=tum --output_dir="results/tum_pose" +``` + +The evaluation results will be saved in `results/tum_pose/_error_log.txt`. + +### ScanNet + +```bash +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=29604 launch.py --mode=eval_pose \ + --pretrained="checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth" \ + --eval_dataset=scannet --output_dir="results/scannet_pose" +``` + +The evaluation results will be saved in `results/scannet_pose/_error_log.txt`. + +## Evaluation Script (Single-Frame Depth) + +### NYU-v2 + +```bash +CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=29604 launch.py --mode=eval_depth \ + --pretrained="checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth" \ + --eval_dataset=nyu --output_dir="results/nyuv2_depth" --no_crop +``` + +The results will be saved in the `results/nyuv2_depth` folder. You could then run the corresponding code block in [depth_metric.ipynb](../depth_metric.ipynb) to evaluate the results. diff --git a/monst3r/data/prepare_training.md b/monst3r/data/prepare_training.md new file mode 100644 index 0000000..cdc667f --- /dev/null +++ b/monst3r/data/prepare_training.md @@ -0,0 +1,73 @@ + +# Dataset Preparation for Training + +We provide scripts to prepare datasets for training, including **PointOdyssey**, **TartanAir**, **Spring**, and **Waymo**. For evaluation, we also provide a script for preparing the **Sintel** dataset. + +> [!NOTE] +> The scripts provided here are for reference only. Please ensure you have obtained the necessary licenses from the original dataset providers before proceeding. + +## Download Pre-Trained Models +To download the pre-trained models, run the following commands: +```bash +cd data +wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P ../checkpoints/ +cd .. +``` + +## Dataset Setup + +### PointOdyssey +To download and prepare the **PointOdyssey** dataset, execute: +```bash +cd data +bash download_pointodyssey.sh +cd .. +``` + +### TartanAir +To download and prepare the **TartanAir** dataset, execute: +```bash +cd data +bash download_tartanair.sh +cd .. +``` + +### Spring +To download and prepare the **Spring** dataset, execute: +```bash +cd data +bash download_spring.sh +cd .. +``` + +### Waymo +To download and prepare the **Waymo** dataset, follow these steps: + +1. Set up Google Cloud SDK (if you haven't done so already): +```bash +curl https://sdk.cloud.google.com | bash +exec -l $SHELL +gcloud init +gcloud auth login +``` + +2. Download the Waymo dataset: +```bash +cd data +bash download_waymo.sh +cd .. +``` + +3. Preprocess the dataset and create training pairs: +```bash +python datasets_preprocess/preprocess_waymo.py +python datasets_preprocess/waymo_make_pairs.py +``` + +## Sintel (Evaluation) +To download and prepare the **Sintel** dataset for evaluation, execute: +```bash +cd data +bash download_sintel.sh +cd .. +``` diff --git a/monst3r/datasets_preprocess/habitat/README.md b/monst3r/datasets_preprocess/habitat/README.md new file mode 100644 index 0000000..3a24120 --- /dev/null +++ b/monst3r/datasets_preprocess/habitat/README.md @@ -0,0 +1,66 @@ +## Steps to reproduce synthetic training data using the Habitat-Sim simulator + +### Create a conda environment +```bash +conda create -n habitat python=3.8 habitat-sim=0.2.1 headless=2.0 -c aihabitat -c conda-forge +conda active habitat +conda install pytorch -c pytorch +pip install opencv-python tqdm +``` + +or (if you get the error `For headless systems, compile with --headless for EGL support`) +``` +git clone --branch stable https://github.com/facebookresearch/habitat-sim.git +cd habitat-sim + +conda create -n habitat python=3.9 cmake=3.14.0 +conda activate habitat +pip install . -v +conda install pytorch -c pytorch +pip install opencv-python tqdm +``` + +### Download Habitat-Sim scenes +Download Habitat-Sim scenes: +- Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md +- We used scenes from the HM3D, habitat-test-scenes, ReplicaCad and ScanNet datasets. +- Please put the scenes in a directory `$SCENES_DIR` following the structure below: +(Note: the habitat-sim dataset installer may install an incompatible version for ReplicaCAD backed lighting. +The correct scene dataset can be dowloaded from Huggingface: `git clone git@hf.co:datasets/ai-habitat/ReplicaCAD_baked_lighting`). +``` +$SCENES_DIR/ +├──hm3d/ +├──gibson/ +├──habitat-test-scenes/ +├──ReplicaCAD_baked_lighting/ +└──scannet/ +``` + +### Download renderings metadata + +Download metadata corresponding to each scene and extract them into a directory `$METADATA_DIR` +```bash +wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz +tar -xvzf habitat_5views_v1_512x512_metadata.tar.gz +``` + +### Render the scenes + +Render the scenes in an output directory `$OUTPUT_DIR` +```bash +export METADATA_DIR="/path/to/habitat/5views_v1_512x512_metadata" +export SCENES_DIR="/path/to/habitat/data/scene_datasets/" +export OUTPUT_DIR="data/habitat_processed" +cd datasets_preprocess/habitat/ +export PYTHONPATH=$(pwd) +# Print commandlines to generate images corresponding to each scene +python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR +# Launch these commandlines in parallel e.g. using GNU-Parallel as follows: +python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR | parallel -j 16 +``` + +### Make a list of scenes + +```bash +python find_scenes.py --root $OUTPUT_DIR +``` \ No newline at end of file diff --git a/monst3r/datasets_preprocess/habitat/find_scenes.py b/monst3r/datasets_preprocess/habitat/find_scenes.py new file mode 100644 index 0000000..1a26349 --- /dev/null +++ b/monst3r/datasets_preprocess/habitat/find_scenes.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Script to export the list of scenes for habitat (after having rendered them). +# Usage: +# python3 datasets_preprocess/preprocess_co3d.py --root data/habitat_processed +# -------------------------------------------------------- +import numpy as np +import os +from collections import defaultdict +from tqdm import tqdm + + +def find_all_scenes(habitat_root, n_scenes=[100000]): + np.random.seed(777) + + try: + fpath = os.path.join(habitat_root, f'Habitat_all_scenes.txt') + list_subscenes = open(fpath).read().splitlines() + + except IOError: + if input('parsing sub-folders to find scenes? (y/n) ') != 'y': + return + list_subscenes = [] + for root, dirs, files in tqdm(os.walk(habitat_root)): + for f in files: + if not f.endswith('_1_depth.exr'): + continue + scene = os.path.join(os.path.relpath(root, habitat_root), f.replace('_1_depth.exr', '')) + if hash(scene) % 1000 == 0: + print('... adding', scene) + list_subscenes.append(scene) + + with open(fpath, 'w') as f: + f.write('\n'.join(list_subscenes)) + print(f'>> wrote {fpath}') + + print(f'Loaded {len(list_subscenes)} sub-scenes') + + # separate scenes + list_scenes = defaultdict(list) + for scene in list_subscenes: + scene, id = os.path.split(scene) + list_scenes[scene].append(id) + + list_scenes = list(list_scenes.items()) + print(f'from {len(list_scenes)} scenes in total') + + np.random.shuffle(list_scenes) + train_scenes = list_scenes[len(list_scenes)//10:] + val_scenes = list_scenes[:len(list_scenes)//10] + + def write_scene_list(scenes, n, fpath): + sub_scenes = [os.path.join(scene, id) for scene, ids in scenes for id in ids] + np.random.shuffle(sub_scenes) + + if len(sub_scenes) < n: + return + + with open(fpath, 'w') as f: + f.write('\n'.join(sub_scenes[:n])) + print(f'>> wrote {fpath}') + + for n in n_scenes: + write_scene_list(train_scenes, n, os.path.join(habitat_root, f'Habitat_{n}_scenes_train.txt')) + write_scene_list(val_scenes, n//10, os.path.join(habitat_root, f'Habitat_{n//10}_scenes_val.txt')) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--root", required=True) + parser.add_argument("--n_scenes", nargs='+', default=[1_000, 1_000_000], type=int) + + args = parser.parse_args() + find_all_scenes(args.root, args.n_scenes) diff --git a/monst3r/datasets_preprocess/habitat/habitat_renderer/__init__.py b/monst3r/datasets_preprocess/habitat/habitat_renderer/__init__.py new file mode 100644 index 0000000..a326921 --- /dev/null +++ b/monst3r/datasets_preprocess/habitat/habitat_renderer/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/monst3r/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py b/monst3r/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py new file mode 100644 index 0000000..4a31f11 --- /dev/null +++ b/monst3r/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py @@ -0,0 +1,170 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Render environment maps from 3D meshes using the Habitat Sim simulator. +# -------------------------------------------------------- +import numpy as np +import habitat_sim +import math +from habitat_renderer import projections + +# OpenCV to habitat camera convention transformation +R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0) + +CUBEMAP_FACE_LABELS = ["left", "front", "right", "back", "up", "down"] +# Expressed while considering Habitat coordinates systems +CUBEMAP_FACE_ORIENTATIONS_ROTVEC = [ + [0, math.pi / 2, 0], # Left + [0, 0, 0], # Front + [0, - math.pi / 2, 0], # Right + [0, math.pi, 0], # Back + [math.pi / 2, 0, 0], # Up + [-math.pi / 2, 0, 0],] # Down + +class NoNaviguableSpaceError(RuntimeError): + def __init__(self, *args): + super().__init__(*args) + +class HabitatEnvironmentMapRenderer: + def __init__(self, + scene, + navmesh, + scene_dataset_config_file, + render_equirectangular=False, + equirectangular_resolution=(512, 1024), + render_cubemap=False, + cubemap_resolution=(512, 512), + render_depth=False, + gpu_id=0): + self.scene = scene + self.navmesh = navmesh + self.scene_dataset_config_file = scene_dataset_config_file + self.gpu_id = gpu_id + + self.render_equirectangular = render_equirectangular + self.equirectangular_resolution = equirectangular_resolution + self.equirectangular_projection = projections.EquirectangularProjection(*equirectangular_resolution) + # 3D unit ray associated to each pixel of the equirectangular map + equirectangular_rays = projections.get_projection_rays(self.equirectangular_projection) + # Not needed, but just in case. + equirectangular_rays /= np.linalg.norm(equirectangular_rays, axis=-1, keepdims=True) + # Depth map created by Habitat are produced by warping a cubemap, + # so the values do not correspond to distance to the center and need some scaling. + self.equirectangular_depth_scale_factors = 1.0 / np.max(np.abs(equirectangular_rays), axis=-1) + + self.render_cubemap = render_cubemap + self.cubemap_resolution = cubemap_resolution + + self.render_depth = render_depth + + self.seed = None + self._lazy_initialization() + + def _lazy_initialization(self): + # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly + if self.seed == None: + # Re-seed numpy generator + np.random.seed() + self.seed = np.random.randint(2**32-1) + sim_cfg = habitat_sim.SimulatorConfiguration() + sim_cfg.scene_id = self.scene + if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "": + sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file + sim_cfg.random_seed = self.seed + sim_cfg.load_semantic_mesh = False + sim_cfg.gpu_device_id = self.gpu_id + + sensor_specifications = [] + + # Add cubemaps + if self.render_cubemap: + for face_id, orientation in enumerate(CUBEMAP_FACE_ORIENTATIONS_ROTVEC): + rgb_sensor_spec = habitat_sim.CameraSensorSpec() + rgb_sensor_spec.uuid = f"color_cubemap_{CUBEMAP_FACE_LABELS[face_id]}" + rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR + rgb_sensor_spec.resolution = self.cubemap_resolution + rgb_sensor_spec.hfov = 90 + rgb_sensor_spec.position = [0.0, 0.0, 0.0] + rgb_sensor_spec.orientation = orientation + sensor_specifications.append(rgb_sensor_spec) + + if self.render_depth: + depth_sensor_spec = habitat_sim.CameraSensorSpec() + depth_sensor_spec.uuid = f"depth_cubemap_{CUBEMAP_FACE_LABELS[face_id]}" + depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH + depth_sensor_spec.resolution = self.cubemap_resolution + depth_sensor_spec.hfov = 90 + depth_sensor_spec.position = [0.0, 0.0, 0.0] + depth_sensor_spec.orientation = orientation + sensor_specifications.append(depth_sensor_spec) + + # Add equirectangular map + if self.render_equirectangular: + rgb_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec() + rgb_sensor_spec.uuid = "color_equirectangular" + rgb_sensor_spec.resolution = self.equirectangular_resolution + rgb_sensor_spec.position = [0.0, 0.0, 0.0] + sensor_specifications.append(rgb_sensor_spec) + + if self.render_depth: + depth_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec() + depth_sensor_spec.uuid = "depth_equirectangular" + depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH + depth_sensor_spec.resolution = self.equirectangular_resolution + depth_sensor_spec.position = [0.0, 0.0, 0.0] + depth_sensor_spec.orientation + sensor_specifications.append(depth_sensor_spec) + + agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=sensor_specifications) + + cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg]) + self.sim = habitat_sim.Simulator(cfg) + if self.navmesh is not None and self.navmesh != "": + # Use pre-computed navmesh (the one generated automatically does some weird stuffs like going on top of the roof) + # See https://youtu.be/kunFMRJAu2U?t=1522 regarding navmeshes + self.sim.pathfinder.load_nav_mesh(self.navmesh) + + # Check that the navmesh is not empty + if not self.sim.pathfinder.is_loaded: + # Try to compute a navmesh + navmesh_settings = habitat_sim.NavMeshSettings() + navmesh_settings.set_defaults() + self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True) + + # Check that the navmesh is not empty + if not self.sim.pathfinder.is_loaded: + raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})") + + self.agent = self.sim.initialize_agent(agent_id=0) + + def close(self): + if hasattr(self, 'sim'): + self.sim.close() + + def __del__(self): + self.close() + + def render_viewpoint(self, viewpoint_position): + agent_state = habitat_sim.AgentState() + agent_state.position = viewpoint_position + # agent_state.rotation = viewpoint_orientation + self.agent.set_state(agent_state) + viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0) + + try: + # Depth map values have been obtained using cubemap rendering internally, + # so they do not really correspond to distance to the viewpoint in practice + # and they need some scaling + viewpoint_observations["depth_equirectangular"] *= self.equirectangular_depth_scale_factors + except KeyError: + pass + + data = dict(observations=viewpoint_observations, position=viewpoint_position) + return data + + def up_direction(self): + return np.asarray(habitat_sim.geo.UP).tolist() + + def R_cam_to_world(self): + return R_OPENCV2HABITAT.tolist() diff --git a/monst3r/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py b/monst3r/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py new file mode 100644 index 0000000..b86238b --- /dev/null +++ b/monst3r/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py @@ -0,0 +1,93 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Generate pairs of crops from a dataset of environment maps. +# -------------------------------------------------------- +import os +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" # noqa +import cv2 +import collections +from habitat_renderer import projections, projections_conversions +from habitat_renderer.habitat_sim_envmaps_renderer import HabitatEnvironmentMapRenderer + +ViewpointData = collections.namedtuple("ViewpointData", ["colormap", "distancemap", "pointmap", "position"]) + +class HabitatMultiviewCrops: + def __init__(self, + scene, + navmesh, + scene_dataset_config_file, + equirectangular_resolution=(400, 800), + crop_resolution=(240, 320), + pixel_jittering_iterations=5, + jittering_noise_level=1.0): + self.crop_resolution = crop_resolution + + self.pixel_jittering_iterations = pixel_jittering_iterations + self.jittering_noise_level = jittering_noise_level + + # Instanciate the low resolution habitat sim renderer + self.lowres_envmap_renderer = HabitatEnvironmentMapRenderer(scene=scene, + navmesh=navmesh, + scene_dataset_config_file=scene_dataset_config_file, + equirectangular_resolution=equirectangular_resolution, + render_depth=True, + render_equirectangular=True) + self.R_cam_to_world = np.asarray(self.lowres_envmap_renderer.R_cam_to_world()) + self.up_direction = np.asarray(self.lowres_envmap_renderer.up_direction()) + + # Projection applied by each environment map + self.envmap_height, self.envmap_width = self.lowres_envmap_renderer.equirectangular_resolution + base_projection = projections.EquirectangularProjection(self.envmap_height, self.envmap_width) + self.envmap_projection = projections.RotatedProjection(base_projection, self.R_cam_to_world.T) + # 3D Rays map associated to each envmap + self.envmap_rays = projections.get_projection_rays(self.envmap_projection) + + def compute_pointmap(self, distancemap, position): + # Point cloud associated to each ray + return self.envmap_rays * distancemap[:, :, None] + position + + def render_viewpoint_data(self, position): + data = self.lowres_envmap_renderer.render_viewpoint(np.asarray(position)) + colormap = data['observations']['color_equirectangular'][..., :3] # Ignore the alpha channel + distancemap = data['observations']['depth_equirectangular'] + pointmap = self.compute_pointmap(distancemap, position) + return ViewpointData(colormap=colormap, distancemap=distancemap, pointmap=pointmap, position=position) + + def extract_cropped_camera(self, projection, color_image, distancemap, pointmap, voxelmap=None): + remapper = projections_conversions.RemapProjection(input_projection=self.envmap_projection, output_projection=projection, + pixel_jittering_iterations=self.pixel_jittering_iterations, jittering_noise_level=self.jittering_noise_level) + cropped_color_image = remapper.convert( + color_image, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False) + cropped_distancemap = remapper.convert( + distancemap, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_WRAP, single_map=True) + cropped_pointmap = remapper.convert(pointmap, interpolation=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_WRAP, single_map=True) + cropped_voxelmap = (None if voxelmap is None else + remapper.convert(voxelmap, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_WRAP, single_map=True)) + # Convert the distance map into a depth map + cropped_depthmap = np.asarray( + cropped_distancemap / np.linalg.norm(remapper.output_rays, axis=-1), dtype=cropped_distancemap.dtype) + + return cropped_color_image, cropped_depthmap, cropped_pointmap, cropped_voxelmap + +def perspective_projection_to_dict(persp_projection, position): + """ + Serialization-like function.""" + camera_params = dict(camera_intrinsics=projections.colmap_to_opencv_intrinsics(persp_projection.base_projection.K).tolist(), + size=(persp_projection.base_projection.width, persp_projection.base_projection.height), + R_cam2world=persp_projection.R_to_base_projection.T.tolist(), + t_cam2world=position) + return camera_params + + +def dict_to_perspective_projection(camera_params): + K = projections.opencv_to_colmap_intrinsics(np.asarray(camera_params["camera_intrinsics"])) + size = camera_params["size"] + R_cam2world = np.asarray(camera_params["R_cam2world"]) + projection = projections.PerspectiveProjection(K, height=size[1], width=size[0]) + projection = projections.RotatedProjection(projection, R_to_base_projection=R_cam2world.T) + position = camera_params["t_cam2world"] + return projection, position \ No newline at end of file diff --git a/monst3r/datasets_preprocess/habitat/habitat_renderer/projections.py b/monst3r/datasets_preprocess/habitat/habitat_renderer/projections.py new file mode 100644 index 0000000..4db1f79 --- /dev/null +++ b/monst3r/datasets_preprocess/habitat/habitat_renderer/projections.py @@ -0,0 +1,151 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Various 3D/2D projection utils, useful to sample virtual cameras. +# -------------------------------------------------------- +import numpy as np + +class EquirectangularProjection: + """ + Convention for the central pixel of the equirectangular map similar to OpenCV perspective model: + +X from left to right + +Y from top to bottom + +Z going outside the camera + EXCEPT that the top left corner of the image is assumed to have (0,0) coordinates (OpenCV assumes (-0.5,-0.5)) + """ + + def __init__(self, height, width): + self.height = height + self.width = width + self.u_scaling = (2 * np.pi) / self.width + self.v_scaling = np.pi / self.height + + def unproject(self, u, v): + """ + Args: + u, v: 2D coordinates + Returns: + unnormalized 3D rays. + """ + longitude = self.u_scaling * u - np.pi + minus_latitude = self.v_scaling * v - np.pi/2 + + cos_latitude = np.cos(minus_latitude) + x, z = np.sin(longitude) * cos_latitude, np.cos(longitude) * cos_latitude + y = np.sin(minus_latitude) + + rays = np.stack([x, y, z], axis=-1) + return rays + + def project(self, rays): + """ + Args: + rays: Bx3 array of 3D rays. + Returns: + u, v: tuple of 2D coordinates. + """ + rays = rays / np.linalg.norm(rays, axis=-1, keepdims=True) + x, y, z = [rays[..., i] for i in range(3)] + + longitude = np.arctan2(x, z) + minus_latitude = np.arcsin(y) + + u = (longitude + np.pi) * (1.0 / self.u_scaling) + v = (minus_latitude + np.pi/2) * (1.0 / self.v_scaling) + return u, v + + +class PerspectiveProjection: + """ + OpenCV convention: + World space: + +X from left to right + +Y from top to bottom + +Z going outside the camera + Pixel space: + +u from left to right + +v from top to bottom + EXCEPT that the top left corner of the image is assumed to have (0,0) coordinates (OpenCV assumes (-0.5,-0.5)). + """ + + def __init__(self, K, height, width): + self.height = height + self.width = width + self.K = K + self.Kinv = np.linalg.inv(K) + + def project(self, rays): + uv_homogeneous = np.einsum("ik, ...k -> ...i", self.K, rays) + uv = uv_homogeneous[..., :2] / uv_homogeneous[..., 2, None] + return uv[..., 0], uv[..., 1] + + def unproject(self, u, v): + uv_homogeneous = np.stack((u, v, np.ones_like(u)), axis=-1) + rays = np.einsum("ik, ...k -> ...i", self.Kinv, uv_homogeneous) + return rays + + +class RotatedProjection: + def __init__(self, base_projection, R_to_base_projection): + self.base_projection = base_projection + self.R_to_base_projection = R_to_base_projection + + @property + def width(self): + return self.base_projection.width + + @property + def height(self): + return self.base_projection.height + + def project(self, rays): + if self.R_to_base_projection is not None: + rays = np.einsum("ik, ...k -> ...i", self.R_to_base_projection, rays) + return self.base_projection.project(rays) + + def unproject(self, u, v): + rays = self.base_projection.unproject(u, v) + if self.R_to_base_projection is not None: + rays = np.einsum("ik, ...k -> ...i", self.R_to_base_projection.T, rays) + return rays + +def get_projection_rays(projection, noise_level=0): + """ + Return a 2D map of 3D rays corresponding to the projection. + If noise_level > 0, add some jittering noise to these rays. + """ + grid_u, grid_v = np.meshgrid(0.5 + np.arange(projection.width), 0.5 + np.arange(projection.height)) + if noise_level > 0: + grid_u += np.clip(0, noise_level * np.random.uniform(-0.5, 0.5, size=grid_u.shape), projection.width) + grid_v += np.clip(0, noise_level * np.random.uniform(-0.5, 0.5, size=grid_v.shape), projection.height) + return projection.unproject(grid_u, grid_v) + +def compute_camera_intrinsics(height, width, hfov): + f = width/2 / np.tan(hfov/2 * np.pi/180) + cu, cv = width/2, height/2 + return f, cu, cv + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + return K + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + return K \ No newline at end of file diff --git a/monst3r/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py b/monst3r/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py new file mode 100644 index 0000000..4bcfed4 --- /dev/null +++ b/monst3r/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py @@ -0,0 +1,45 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Remap data from one projection to an other +# -------------------------------------------------------- +import numpy as np +import cv2 +from habitat_renderer import projections + +class RemapProjection: + def __init__(self, input_projection, output_projection, pixel_jittering_iterations=0, jittering_noise_level=0): + """ + Some naive random jittering can be introduced in the remapping to mitigate aliasing artecfacts. + """ + assert jittering_noise_level >= 0 + assert pixel_jittering_iterations >= 0 + + maps = [] + # Initial map + self.output_rays = projections.get_projection_rays(output_projection) + map_u, map_v = input_projection.project(self.output_rays) + map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32) + maps.append((map_u, map_v)) + + for _ in range(pixel_jittering_iterations): + # Define multiple mappings using some coordinates jittering to mitigate aliasing effects + crop_rays = projections.get_projection_rays(output_projection, jittering_noise_level) + map_u, map_v = input_projection.project(crop_rays) + map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32) + maps.append((map_u, map_v)) + self.maps = maps + + def convert(self, img, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False): + remapped = [] + for map_u, map_v in self.maps: + res = cv2.remap(img, map_u, map_v, interpolation=interpolation, borderMode=borderMode) + remapped.append(res) + if single_map: + break + if len(remapped) == 1: + res = remapped[0] + else: + res = np.asarray(np.mean(remapped, axis=0), dtype=img.dtype) + return res diff --git a/monst3r/datasets_preprocess/habitat/preprocess_habitat.py b/monst3r/datasets_preprocess/habitat/preprocess_habitat.py new file mode 100644 index 0000000..cacbe24 --- /dev/null +++ b/monst3r/datasets_preprocess/habitat/preprocess_habitat.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# main executable for preprocessing habitat +# export METADATA_DIR="/path/to/habitat/5views_v1_512x512_metadata" +# export SCENES_DIR="/path/to/habitat/data/scene_datasets/" +# export OUTPUT_DIR="data/habitat_processed" +# export PYTHONPATH=$(pwd) +# python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR | parallel -j 16 +# -------------------------------------------------------- +import os +import glob +import json +import os + +import PIL.Image +import json +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" # noqa +import cv2 +from habitat_renderer import multiview_crop_generator +from tqdm import tqdm + + +def preprocess_metadata(metadata_filename, + scenes_dir, + output_dir, + crop_resolution=[512, 512], + equirectangular_resolution=None, + fix_existing_dataset=False): + # Load data + with open(metadata_filename, "r") as f: + metadata = json.load(f) + + if metadata["scene_dataset_config_file"] == "": + scene = os.path.join(scenes_dir, metadata["scene"]) + scene_dataset_config_file = "" + else: + scene = metadata["scene"] + scene_dataset_config_file = os.path.join(scenes_dir, metadata["scene_dataset_config_file"]) + navmesh = None + + # Use 4 times the crop size as resolution for rendering the environment map. + max_res = max(crop_resolution) + + if equirectangular_resolution == None: + # Use 4 times the crop size as resolution for rendering the environment map. + max_res = max(crop_resolution) + equirectangular_resolution = (4*max_res, 8*max_res) + + print("equirectangular_resolution:", equirectangular_resolution) + + if os.path.exists(output_dir) and not fix_existing_dataset: + raise FileExistsError(output_dir) + + # Lazy initialization + highres_dataset = None + + for batch_label, batch in tqdm(metadata["view_batches"].items()): + for view_label, view_params in batch.items(): + + assert view_params["size"] == crop_resolution + label = f"{batch_label}_{view_label}" + + output_camera_params_filename = os.path.join(output_dir, f"{label}_camera_params.json") + if fix_existing_dataset and os.path.isfile(output_camera_params_filename): + # Skip generation if we are fixing a dataset and the corresponding output file already exists + continue + + # Lazy initialization + if highres_dataset is None: + highres_dataset = multiview_crop_generator.HabitatMultiviewCrops(scene=scene, + navmesh=navmesh, + scene_dataset_config_file=scene_dataset_config_file, + equirectangular_resolution=equirectangular_resolution, + crop_resolution=crop_resolution,) + os.makedirs(output_dir, exist_ok=bool(fix_existing_dataset)) + + # Generate a higher resolution crop + original_projection, position = multiview_crop_generator.dict_to_perspective_projection(view_params) + # Render an envmap at the given position + viewpoint_data = highres_dataset.render_viewpoint_data(position) + + projection = original_projection + colormap, depthmap, pointmap, _ = highres_dataset.extract_cropped_camera( + projection, viewpoint_data.colormap, viewpoint_data.distancemap, viewpoint_data.pointmap) + + camera_params = multiview_crop_generator.perspective_projection_to_dict(projection, position) + + # Color image + PIL.Image.fromarray(colormap).save(os.path.join(output_dir, f"{label}.jpeg")) + # Depth image + cv2.imwrite(os.path.join(output_dir, f"{label}_depth.exr"), + depthmap, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + # Camera parameters + with open(output_camera_params_filename, "w") as f: + json.dump(camera_params, f) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_dir", required=True) + parser.add_argument("--scenes_dir", required=True) + parser.add_argument("--output_dir", required=True) + parser.add_argument("--metadata_filename", default="") + + args = parser.parse_args() + + if args.metadata_filename == "": + # Walk through the metadata dir to generate commandlines + for filename in glob.iglob(os.path.join(args.metadata_dir, "**/metadata.json"), recursive=True): + output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(filename), args.metadata_dir)) + if not os.path.exists(output_dir): + commandline = f"python {__file__} --metadata_filename={filename} --metadata_dir={args.metadata_dir} --scenes_dir={args.scenes_dir} --output_dir={output_dir}" + print(commandline) + else: + preprocess_metadata(metadata_filename=args.metadata_filename, + scenes_dir=args.scenes_dir, + output_dir=args.output_dir) diff --git a/monst3r/datasets_preprocess/path_to_root.py b/monst3r/datasets_preprocess/path_to_root.py new file mode 100644 index 0000000..6e076a1 --- /dev/null +++ b/monst3r/datasets_preprocess/path_to_root.py @@ -0,0 +1,13 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# DUSt3R repo root import +# -------------------------------------------------------- + +import sys +import os.path as path +HERE_PATH = path.normpath(path.dirname(__file__)) +DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../')) +# workaround for sibling import +sys.path.insert(0, DUST3R_REPO_PATH) diff --git a/monst3r/datasets_preprocess/prepare_bonn.py b/monst3r/datasets_preprocess/prepare_bonn.py new file mode 100644 index 0000000..3c53834 --- /dev/null +++ b/monst3r/datasets_preprocess/prepare_bonn.py @@ -0,0 +1,40 @@ +# %% +import glob +import os +import shutil +dirs = glob.glob("../data/bonn/rgbd_bonn_dataset/*/") +dirs = sorted(dirs) +# extract frames +for dir in dirs: + frames = glob.glob(dir + 'rgb/*.png') + frames = sorted(frames) + # sample 110 frames at the stride of 2 + frames = frames[30:140] + # cut frames after 110 + new_dir = dir + 'rgb_110/' + + for frame in frames: + os.makedirs(new_dir, exist_ok=True) + shutil.copy(frame, new_dir) + # print(f'cp {frame} {new_dir}') + + depth_frames = glob.glob(dir + 'depth/*.png') + depth_frames = sorted(depth_frames) + # sample 110 frames at the stride of 2 + depth_frames = depth_frames[30:140] + # cut frames after 110 + new_dir = dir + 'depth_110/' + + for frame in depth_frames: + os.makedirs(new_dir, exist_ok=True) + shutil.copy(frame, new_dir) + # print(f'cp {frame} {new_dir}') +import numpy as np +for dir in dirs: + gt_path = "groundtruth.txt" + gt = np.loadtxt(dir + gt_path) + gt_110 = gt[30:140] + np.savetxt(dir + 'groundtruth_110.txt', gt_110) + + + diff --git a/monst3r/datasets_preprocess/prepare_kitti.py b/monst3r/datasets_preprocess/prepare_kitti.py new file mode 100644 index 0000000..cff61af --- /dev/null +++ b/monst3r/datasets_preprocess/prepare_kitti.py @@ -0,0 +1,47 @@ +# %% +#!/usr/bin/python + +from PIL import Image +import numpy as np + + +def depth_read(filename): + # loads depth map D from png file + # and returns it as a numpy array, + # for details see readme.txt + + depth_png = np.array(Image.open(filename), dtype=int) + # make sure we have a proper 16bit depth map here.. not 8bit! + assert(np.max(depth_png) > 255) + + depth = depth_png.astype(np.float) / 256. + depth[depth_png == 0] = -1. + return depth + +# %% +import glob +import os +import shutil +depth_dirs = glob.glob("../data/kitti/val/*/proj_depth/groundtruth/image_02") +for dir in depth_dirs: + # new depth dir + new_depth_dir = "../data/kitti/depth_selection/val_selection_cropped/groundtruth_depth_gathered/" + dir.split("/")[-4]+"_02" + # print(new_depth_dir) + new_image_dir = "../data/kitti/depth_selection/val_selection_cropped/image_gathered/" + dir.split("/")[-4]+"_02" + os.makedirs(new_depth_dir, exist_ok=True) + os.makedirs(new_image_dir, exist_ok=True) + for depth_file in sorted(glob.glob(dir + "/*.png"))[:110]: #../data/kitti/val/2011_09_26_drive_0002_sync/proj_depth/groundtruth/image_02/0000000005.png + new_path = new_depth_dir + "/" + depth_file.split("/")[-1] + shutil.copy(depth_file, new_path) + # get the path of the corresponding image + mid = "_".join(depth_file.split("/")[4].split("_")[:3]) + image_file = depth_file.replace('val', mid).replace('proj_depth/groundtruth/image_02', 'image_02/data') + print(image_file) + # check if the image file exists + if os.path.exists(image_file): + new_path = new_image_dir + "/" + image_file.split("/")[-1] + shutil.copy(image_file, new_path) + else: + print("Image file does not exist: ", image_file) + + diff --git a/monst3r/datasets_preprocess/prepare_nyuv2.py b/monst3r/datasets_preprocess/prepare_nyuv2.py new file mode 100644 index 0000000..87fbac8 --- /dev/null +++ b/monst3r/datasets_preprocess/prepare_nyuv2.py @@ -0,0 +1,84 @@ +# %% +import h5py +import numpy as np +import os +from glob import glob +from PIL import Image + +# Set the path to your dataset directory +dataset_dir = '../data/nyu-v2/val/official/' + +# Get a list of all .h5 files in the dataset directory +file_paths = glob(os.path.join(dataset_dir, '*.h5')) + +# Create output directories for images and depth data +output_image_dir = '../data/nyu-v2/val/nyu_images/' +output_depth_dir = '../data/nyu-v2/val/nyu_depths/' +os.makedirs(output_image_dir, exist_ok=True) +os.makedirs(output_depth_dir, exist_ok=True) + +for file_path in file_paths: + with h5py.File(file_path, 'r') as h5file: + # Read depth and rgb data + depth_data = h5file['depth'][:] + rgb_data = h5file['rgb'][:] + + # Convert rgb data from (3, H, W) to (H, W, 3) + rgb_data = np.transpose(rgb_data, (1, 2, 0)) + + # Ensure that rgb_data is of type uint8 + if rgb_data.dtype != np.uint8: + rgb_data = rgb_data.astype(np.uint8) + + # Get the base filename without extension + base_name = os.path.splitext(os.path.basename(file_path))[0] + + # Save the RGB image as PNG + rgb_image = Image.fromarray(rgb_data) + rgb_image.save(os.path.join(output_image_dir, f'{base_name}.png')) + + # Save the depth data as NPY file + np.save(os.path.join(output_depth_dir, f'{base_name}.npy'), depth_data) + + print(f'Processed {base_name}') + + +# %% +import os +import numpy as np +from PIL import Image + +# Paths +depth_npy_dir = '../data/nyu-v2/val/nyu_depths' +output_img_dir = '../data/nyu-v2/val/nyu_depth_imgs' + +# Ensure the output directory exists +os.makedirs(output_img_dir, exist_ok=True) + +# Iterate over all .npy files in the depth directory +for npy_file in os.listdir(depth_npy_dir): + if npy_file.endswith('.npy'): + # Load depth data from .npy file + depth_path = os.path.join(depth_npy_dir, npy_file) + depth_data = np.load(depth_path) + + # Normalize depth data to range [0, 255] for saving as an image + depth_min = depth_data.min() + depth_max = depth_data.max() + depth_normalized = (depth_data - depth_min) / (depth_max - depth_min) + depth_uint8 = (depth_normalized * 255).astype(np.uint8) + + # Convert to an image + depth_img = Image.fromarray(depth_uint8) + + # Save as PNG file + img_name = os.path.splitext(npy_file)[0] + '.png' + img_save_path = os.path.join(output_img_dir, img_name) + depth_img.save(img_save_path) + + print(f'Saved {img_save_path}') + +print("Conversion completed!") + + + diff --git a/monst3r/datasets_preprocess/prepare_scannet.py b/monst3r/datasets_preprocess/prepare_scannet.py new file mode 100644 index 0000000..5b189b2 --- /dev/null +++ b/monst3r/datasets_preprocess/prepare_scannet.py @@ -0,0 +1,33 @@ +import glob +import os +import shutil +import numpy as np + +seq_list = sorted(os.listdir("../data/scannetv2")) +for seq in seq_list: + img_pathes = sorted(glob.glob(f"../data/scannetv2/{seq}/color/*.jpg"), key=lambda x: int(os.path.basename(x).split('.')[0])) + depth_pathes = sorted(glob.glob(f"../data/scannetv2/{seq}/depth/*.png"), key=lambda x: int(os.path.basename(x).split('.')[0])) + pose_pathes = sorted(glob.glob(f"../data/scannetv2/{seq}/pose/*.txt"), key=lambda x: int(os.path.basename(x).split('.')[0])) + print(f"{seq}: {len(img_pathes)} {len(depth_pathes)}") + + new_color_dir = f"../data/scannetv2/{seq}/color_90" + new_depth_dir = f"../data/scannetv2/{seq}/depth_90" + + new_img_pathes = img_pathes[:90*3:3] + new_depth_pathes = depth_pathes[:90*3:3] + new_pose_pathes = pose_pathes[:90*3:3] + + os.makedirs(new_color_dir, exist_ok=True) + os.makedirs(new_depth_dir, exist_ok=True) + + for i, (img_path, depth_path) in enumerate(zip(new_img_pathes, new_depth_pathes)): + shutil.copy(img_path, f"{new_color_dir}/frame_{i:04d}.jpg") + shutil.copy(depth_path, f"{new_depth_dir}/frame_{i:04d}.png") + + pose_new_path = f"../data/scannetv2/{seq}/pose_90.txt" + with open(pose_new_path, 'w') as f: + for i, pose_path in enumerate(new_pose_pathes): + with open(pose_path, 'r') as pose_file: + pose = np.loadtxt(pose_file) + pose = pose.reshape(-1) + f.write(f"{' '.join(map(str, pose))}\n") diff --git a/monst3r/datasets_preprocess/prepare_tum.py b/monst3r/datasets_preprocess/prepare_tum.py new file mode 100644 index 0000000..de1654c --- /dev/null +++ b/monst3r/datasets_preprocess/prepare_tum.py @@ -0,0 +1,93 @@ +import glob +import os +import shutil +import numpy as np + +def read_file_list(filename): + """ + Reads a trajectory from a text file. + + File format: + The file format is "stamp d1 d2 d3 ...", where stamp denotes the time stamp (to be matched) + and "d1 d2 d3.." is arbitary data (e.g., a 3D position and 3D orientation) associated to this timestamp. + + Input: + filename -- File name + + Output: + dict -- dictionary of (stamp,data) tuples + + """ + file = open(filename) + data = file.read() + lines = data.replace(","," ").replace("\t"," ").split("\n") + list = [[v.strip() for v in line.split(" ") if v.strip()!=""] for line in lines if len(line)>0 and line[0]!="#"] + list = [(float(l[0]),l[1:]) for l in list if len(l)>1] + return dict(list) + +def associate(first_list, second_list, offset, max_difference): + """ + Associate two dictionaries of (stamp, data). As the time stamps never match exactly, we aim + to find the closest match for every input tuple. + + Input: + first_list -- first dictionary of (stamp, data) tuples + second_list -- second dictionary of (stamp, data) tuples + offset -- time offset between both dictionaries (e.g., to model the delay between the sensors) + max_difference -- search radius for candidate generation + + Output: + matches -- list of matched tuples ((stamp1, data1), (stamp2, data2)) + """ + # Convert keys to sets for efficient removal + first_keys = set(first_list.keys()) + second_keys = set(second_list.keys()) + + potential_matches = [(abs(a - (b + offset)), a, b) + for a in first_keys + for b in second_keys + if abs(a - (b + offset)) < max_difference] + potential_matches.sort() + matches = [] + for diff, a, b in potential_matches: + if a in first_keys and b in second_keys: + first_keys.remove(a) + second_keys.remove(b) + matches.append((a, b)) + + matches.sort() + return matches + +dirs = glob.glob("../data/tum/*/") +dirs = sorted(dirs) +# extract frames +for dir in dirs: + frames = [] + gt = [] + first_file = dir + 'rgb.txt' + second_file = dir + 'groundtruth.txt' + + first_list = read_file_list(first_file) + second_list = read_file_list(second_file) + matches = associate(first_list, second_list, 0.0, 0.02) + + # for a,b in matches[:10]: + # print("%f %s %f %s"%(a," ".join(first_list[a]),b," ".join(second_list[b]))) + for a,b in matches: + frames.append(dir + first_list[a][0]) + gt.append([b]+second_list[b]) + + # sample 90 frames at the stride of 3 + frames = frames[::3][:90] + # cut frames after 90 + new_dir = dir + 'rgb_90/' + + for frame in frames: + os.makedirs(new_dir, exist_ok=True) + shutil.copy(frame, new_dir) + # print(f'cp {frame} {new_dir}') + + gt_90 = gt[::3][:90] + with open(dir + 'groundtruth_90.txt', 'w') as f: + for pose in gt_90: + f.write(f"{' '.join(map(str, pose))}\n") \ No newline at end of file diff --git a/monst3r/datasets_preprocess/preprocess_arkitscenes.py b/monst3r/datasets_preprocess/preprocess_arkitscenes.py new file mode 100644 index 0000000..5dbc103 --- /dev/null +++ b/monst3r/datasets_preprocess/preprocess_arkitscenes.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Script to pre-process the arkitscenes dataset. +# Usage: +# python3 datasets_preprocess/preprocess_arkitscenes.py --arkitscenes_dir /path/to/arkitscenes --precomputed_pairs /path/to/arkitscenes_pairs +# -------------------------------------------------------- +import os +import json +import os.path as osp +import decimal +import argparse +import math +from bisect import bisect_left +from PIL import Image +import numpy as np +import quaternion +from scipy import interpolate +import cv2 + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('--arkitscenes_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/arkitscenes_processed') + return parser + + +def value_to_decimal(value, decimal_places): + decimal.getcontext().rounding = decimal.ROUND_HALF_UP # define rounding method + return decimal.Decimal(str(float(value))).quantize(decimal.Decimal('1e-{}'.format(decimal_places))) + + +def closest(value, sorted_list): + index = bisect_left(sorted_list, value) + if index == 0: + return sorted_list[0] + elif index == len(sorted_list): + return sorted_list[-1] + else: + value_before = sorted_list[index - 1] + value_after = sorted_list[index] + if value_after - value < value - value_before: + return value_after + else: + return value_before + + +def get_up_vectors(pose_device_to_world): + return np.matmul(pose_device_to_world, np.array([[0.0], [-1.0], [0.0], [0.0]])) + + +def get_right_vectors(pose_device_to_world): + return np.matmul(pose_device_to_world, np.array([[1.0], [0.0], [0.0], [0.0]])) + + +def read_traj(traj_path): + quaternions = [] + poses = [] + timestamps = [] + poses_p_to_w = [] + with open(traj_path) as f: + traj_lines = f.readlines() + for line in traj_lines: + tokens = line.split() + assert len(tokens) == 7 + traj_timestamp = float(tokens[0]) + + timestamps_decimal_value = value_to_decimal(traj_timestamp, 3) + timestamps.append(float(timestamps_decimal_value)) # for spline interpolation + + angle_axis = [float(tokens[1]), float(tokens[2]), float(tokens[3])] + r_w_to_p, _ = cv2.Rodrigues(np.asarray(angle_axis)) + t_w_to_p = np.asarray([float(tokens[4]), float(tokens[5]), float(tokens[6])]) + + pose_w_to_p = np.eye(4) + pose_w_to_p[:3, :3] = r_w_to_p + pose_w_to_p[:3, 3] = t_w_to_p + + pose_p_to_w = np.linalg.inv(pose_w_to_p) + + r_p_to_w_as_quat = quaternion.from_rotation_matrix(pose_p_to_w[:3, :3]) + t_p_to_w = pose_p_to_w[:3, 3] + poses_p_to_w.append(pose_p_to_w) + poses.append(t_p_to_w) + quaternions.append(r_p_to_w_as_quat) + return timestamps, poses, quaternions, poses_p_to_w + + +def main(rootdir, pairsdir, outdir): + os.makedirs(outdir, exist_ok=True) + + subdirs = ['Test', 'Training'] + for subdir in subdirs: + if not osp.isdir(osp.join(rootdir, subdir)): + continue + # STEP 1: list all scenes + outsubdir = osp.join(outdir, subdir) + os.makedirs(outsubdir, exist_ok=True) + listfile = osp.join(pairsdir, subdir, 'scene_list.json') + with open(listfile, 'r') as f: + scene_dirs = json.load(f) + + valid_scenes = [] + for scene_subdir in scene_dirs: + out_scene_subdir = osp.join(outsubdir, scene_subdir) + os.makedirs(out_scene_subdir, exist_ok=True) + + scene_dir = osp.join(rootdir, subdir, scene_subdir) + depth_dir = osp.join(scene_dir, 'lowres_depth') + rgb_dir = osp.join(scene_dir, 'vga_wide') + intrinsics_dir = osp.join(scene_dir, 'vga_wide_intrinsics') + traj_path = osp.join(scene_dir, 'lowres_wide.traj') + + # STEP 2: read selected_pairs.npz + selected_pairs_path = osp.join(pairsdir, subdir, scene_subdir, 'selected_pairs.npz') + selected_npz = np.load(selected_pairs_path) + selection, pairs = selected_npz['selection'], selected_npz['pairs'] + selected_sky_direction_scene = str(selected_npz['sky_direction_scene'][0]) + if len(selection) == 0 or len(pairs) == 0: + # not a valid scene + continue + valid_scenes.append(scene_subdir) + + # STEP 3: parse the scene and export the list of valid (K, pose, rgb, depth) and convert images + scene_metadata_path = osp.join(out_scene_subdir, 'scene_metadata.npz') + if osp.isfile(scene_metadata_path): + continue + else: + print(f'parsing {scene_subdir}') + # loads traj + timestamps, poses, quaternions, poses_cam_to_world = read_traj(traj_path) + + poses = np.array(poses) + quaternions = np.array(quaternions, dtype=np.quaternion) + quaternions = quaternion.unflip_rotors(quaternions) + timestamps = np.array(timestamps) + + selected_images = [(basename, basename.split(".png")[0].split("_")[1]) for basename in selection] + timestamps_selected = [float(frame_id) for _, frame_id in selected_images] + + sky_direction_scene, trajectories, intrinsics, images = convert_scene_metadata(scene_subdir, + intrinsics_dir, + timestamps, + quaternions, + poses, + poses_cam_to_world, + selected_images, + timestamps_selected) + assert selected_sky_direction_scene == sky_direction_scene + + os.makedirs(os.path.join(out_scene_subdir, 'vga_wide'), exist_ok=True) + os.makedirs(os.path.join(out_scene_subdir, 'lowres_depth'), exist_ok=True) + assert isinstance(sky_direction_scene, str) + for basename in images: + img_out = os.path.join(out_scene_subdir, 'vga_wide', basename.replace('.png', '.jpg')) + depth_out = os.path.join(out_scene_subdir, 'lowres_depth', basename) + if osp.isfile(img_out) and osp.isfile(depth_out): + continue + + vga_wide_path = osp.join(rgb_dir, basename) + depth_path = osp.join(depth_dir, basename) + + img = Image.open(vga_wide_path) + depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) + + # rotate the image + if sky_direction_scene == 'RIGHT': + try: + img = img.transpose(Image.Transpose.ROTATE_90) + except Exception: + img = img.transpose(Image.ROTATE_90) + depth = cv2.rotate(depth, cv2.ROTATE_90_COUNTERCLOCKWISE) + elif sky_direction_scene == 'LEFT': + try: + img = img.transpose(Image.Transpose.ROTATE_270) + except Exception: + img = img.transpose(Image.ROTATE_270) + depth = cv2.rotate(depth, cv2.ROTATE_90_CLOCKWISE) + elif sky_direction_scene == 'DOWN': + try: + img = img.transpose(Image.Transpose.ROTATE_180) + except Exception: + img = img.transpose(Image.ROTATE_180) + depth = cv2.rotate(depth, cv2.ROTATE_180) + + W, H = img.size + if not osp.isfile(img_out): + img.save(img_out) + + depth = cv2.resize(depth, (W, H), interpolation=cv2.INTER_NEAREST_EXACT) + if not osp.isfile(depth_out): # avoid destroying the base dataset when you mess up the paths + cv2.imwrite(depth_out, depth) + + # save at the end + np.savez(scene_metadata_path, + trajectories=trajectories, + intrinsics=intrinsics, + images=images, + pairs=pairs) + + outlistfile = osp.join(outsubdir, 'scene_list.json') + with open(outlistfile, 'w') as f: + json.dump(valid_scenes, f) + + # STEP 5: concat all scene_metadata.npz into a single file + scene_data = {} + for scene_subdir in valid_scenes: + scene_metadata_path = osp.join(outsubdir, scene_subdir, 'scene_metadata.npz') + with np.load(scene_metadata_path) as data: + trajectories = data['trajectories'] + intrinsics = data['intrinsics'] + images = data['images'] + pairs = data['pairs'] + scene_data[scene_subdir] = {'trajectories': trajectories, + 'intrinsics': intrinsics, + 'images': images, + 'pairs': pairs} + offset = 0 + counts = [] + scenes = [] + sceneids = [] + images = [] + intrinsics = [] + trajectories = [] + pairs = [] + for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()): + num_imgs = data['images'].shape[0] + img_pairs = data['pairs'] + + scenes.append(scene_subdir) + sceneids.extend([scene_idx] * num_imgs) + + images.append(data['images']) + + K = np.expand_dims(np.eye(3), 0).repeat(num_imgs, 0) + K[:, 0, 0] = [fx for _, _, fx, _, _, _ in data['intrinsics']] + K[:, 1, 1] = [fy for _, _, _, fy, _, _ in data['intrinsics']] + K[:, 0, 2] = [hw for _, _, _, _, hw, _ in data['intrinsics']] + K[:, 1, 2] = [hh for _, _, _, _, _, hh in data['intrinsics']] + + intrinsics.append(K) + trajectories.append(data['trajectories']) + + # offset pairs + img_pairs[:, 0:2] += offset + pairs.append(img_pairs) + counts.append(offset) + + offset += num_imgs + + images = np.concatenate(images, axis=0) + intrinsics = np.concatenate(intrinsics, axis=0) + trajectories = np.concatenate(trajectories, axis=0) + pairs = np.concatenate(pairs, axis=0) + np.savez(osp.join(outsubdir, 'all_metadata.npz'), + counts=counts, + scenes=scenes, + sceneids=sceneids, + images=images, + intrinsics=intrinsics, + trajectories=trajectories, + pairs=pairs) + + +def convert_scene_metadata(scene_subdir, intrinsics_dir, + timestamps, quaternions, poses, poses_cam_to_world, + selected_images, timestamps_selected): + # find scene orientation + sky_direction_scene, rotated_to_cam = find_scene_orientation(poses_cam_to_world) + + # find/compute pose for selected timestamps + # most images have a valid timestamp / exact pose associated + timestamps_selected = np.array(timestamps_selected) + spline = interpolate.interp1d(timestamps, poses, kind='linear', axis=0) + interpolated_rotations = quaternion.squad(quaternions, timestamps, timestamps_selected) + interpolated_positions = spline(timestamps_selected) + + trajectories = [] + intrinsics = [] + images = [] + for i, (basename, frame_id) in enumerate(selected_images): + intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{frame_id}.pincam") + if not osp.exists(intrinsic_fn): + intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{float(frame_id) - 0.001:.3f}.pincam") + if not osp.exists(intrinsic_fn): + intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{float(frame_id) + 0.001:.3f}.pincam") + assert osp.exists(intrinsic_fn) + w, h, fx, fy, hw, hh = np.loadtxt(intrinsic_fn) # PINHOLE + + pose = np.eye(4) + pose[:3, :3] = quaternion.as_rotation_matrix(interpolated_rotations[i]) + pose[:3, 3] = interpolated_positions[i] + + images.append(basename) + if sky_direction_scene == 'RIGHT' or sky_direction_scene == 'LEFT': + intrinsics.append([h, w, fy, fx, hh, hw]) # swapped intrinsics + else: + intrinsics.append([w, h, fx, fy, hw, hh]) + trajectories.append(pose @ rotated_to_cam) # pose_cam_to_world @ rotated_to_cam = rotated(cam) to world + + return sky_direction_scene, trajectories, intrinsics, images + + +def find_scene_orientation(poses_cam_to_world): + if len(poses_cam_to_world) > 0: + up_vector = sum(get_up_vectors(p) for p in poses_cam_to_world) / len(poses_cam_to_world) + right_vector = sum(get_right_vectors(p) for p in poses_cam_to_world) / len(poses_cam_to_world) + up_world = np.array([[0.0], [0.0], [1.0], [0.0]]) + else: + up_vector = np.array([[0.0], [-1.0], [0.0], [0.0]]) + right_vector = np.array([[1.0], [0.0], [0.0], [0.0]]) + up_world = np.array([[0.0], [0.0], [1.0], [0.0]]) + + # value between 0, 180 + device_up_to_world_up_angle = np.arccos(np.clip(np.dot(np.transpose(up_world), + up_vector), -1.0, 1.0)).item() * 180.0 / np.pi + device_right_to_world_up_angle = np.arccos(np.clip(np.dot(np.transpose(up_world), + right_vector), -1.0, 1.0)).item() * 180.0 / np.pi + + up_closest_to_90 = abs(device_up_to_world_up_angle - 90.0) < abs(device_right_to_world_up_angle - 90.0) + if up_closest_to_90: + assert abs(device_up_to_world_up_angle - 90.0) < 45.0 + # LEFT + if device_right_to_world_up_angle > 90.0: + sky_direction_scene = 'LEFT' + cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, math.pi / 2.0]) + else: + # note that in metadata.csv RIGHT does not exist, but again it's not accurate... + # well, turns out there are scenes oriented like this + # for example Training/41124801 + sky_direction_scene = 'RIGHT' + cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, -math.pi / 2.0]) + else: + # right is close to 90 + assert abs(device_right_to_world_up_angle - 90.0) < 45.0 + if device_up_to_world_up_angle > 90.0: + sky_direction_scene = 'DOWN' + cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, math.pi]) + else: + sky_direction_scene = 'UP' + cam_to_rotated_q = quaternion.quaternion(1, 0, 0, 0) + cam_to_rotated = np.eye(4) + cam_to_rotated[:3, :3] = quaternion.as_rotation_matrix(cam_to_rotated_q) + rotated_to_cam = np.linalg.inv(cam_to_rotated) + return sky_direction_scene, rotated_to_cam + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.arkitscenes_dir, args.precomputed_pairs, args.output_dir) diff --git a/monst3r/datasets_preprocess/preprocess_blendedMVS.py b/monst3r/datasets_preprocess/preprocess_blendedMVS.py new file mode 100644 index 0000000..d227937 --- /dev/null +++ b/monst3r/datasets_preprocess/preprocess_blendedMVS.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Preprocessing code for the BlendedMVS dataset +# dataset at https://github.com/YoYo000/BlendedMVS +# 1) Download BlendedMVS.zip +# 2) Download BlendedMVS+.zip +# 3) Download BlendedMVS++.zip +# 4) Unzip everything in the same /path/to/tmp/blendedMVS/ directory +# 5) python datasets_preprocess/preprocess_blendedMVS.py --blendedmvs_dir /path/to/tmp/blendedMVS/ +# -------------------------------------------------------- +import os +import os.path as osp +import re +from tqdm import tqdm +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +import path_to_root # noqa +from dust3r.utils.parallel import parallel_threads +from dust3r.datasets.utils import cropping # noqa + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--blendedmvs_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/blendedmvs_processed') + return parser + + +def main(db_root, pairs_path, output_dir): + print('>> Listing all sequences') + sequences = [f for f in os.listdir(db_root) if len(f) == 24] + # should find 502 scenes + assert sequences, f'did not found any sequences at {db_root}' + print(f' (found {len(sequences)} sequences)') + + for i, seq in enumerate(tqdm(sequences)): + out_dir = osp.join(output_dir, seq) + os.makedirs(out_dir, exist_ok=True) + + # generate the crops + root = osp.join(db_root, seq) + cam_dir = osp.join(root, 'cams') + func_args = [(root, f[:-8], out_dir) for f in os.listdir(cam_dir) if not f.startswith('pair')] + parallel_threads(load_crop_and_save, func_args, star_args=True, leave=False) + + # verify that all pairs are there + pairs = np.load(pairs_path) + for seqh, seql, img1, img2, score in tqdm(pairs): + for view_index in [img1, img2]: + impath = osp.join(output_dir, f"{seqh:08x}{seql:016x}", f"{view_index:08n}.jpg") + assert osp.isfile(impath), f'missing image at {impath=}' + + print(f'>> Done, saved everything in {output_dir}/') + + +def load_crop_and_save(root, img, out_dir): + if osp.isfile(osp.join(out_dir, img + '.npz')): + return # already done + + # load everything + intrinsics_in, R_camin2world, t_camin2world = _load_pose(osp.join(root, 'cams', img + '_cam.txt')) + color_image_in = cv2.cvtColor(cv2.imread(osp.join(root, 'blended_images', img + + '.jpg'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + depthmap_in = load_pfm_file(osp.join(root, 'rendered_depth_maps', img + '.pfm')) + + # do the crop + H, W = color_image_in.shape[:2] + assert H * 4 == W * 3 + image, depthmap, intrinsics_out, R_in2out = _crop_image(intrinsics_in, color_image_in, depthmap_in, (512, 384)) + + # write everything + image.save(osp.join(out_dir, img + '.jpg'), quality=80) + cv2.imwrite(osp.join(out_dir, img + '.exr'), depthmap) + + # New camera parameters + R_camout2world = R_camin2world @ R_in2out.T + t_camout2world = t_camin2world + np.savez(osp.join(out_dir, img + '.npz'), intrinsics=intrinsics_out, + R_cam2world=R_camout2world, t_cam2world=t_camout2world) + + +def _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_out=(800, 800)): + image, depthmap, intrinsics_out = cropping.rescale_image_depthmap( + color_image_in, depthmap_in, intrinsics_in, resolution_out) + R_in2out = np.eye(3) + return image, depthmap, intrinsics_out, R_in2out + + +def _load_pose(path, ret_44=False): + f = open(path) + RT = np.loadtxt(f, skiprows=1, max_rows=4, dtype=np.float32) + assert RT.shape == (4, 4) + RT = np.linalg.inv(RT) # world2cam to cam2world + + K = np.loadtxt(f, skiprows=2, max_rows=3, dtype=np.float32) + assert K.shape == (3, 3) + + if ret_44: + return K, RT + return K, RT[:3, :3], RT[:3, 3] # , depth_uint8_to_f32 + + +def load_pfm_file(file_path): + with open(file_path, 'rb') as file: + header = file.readline().decode('UTF-8').strip() + + if header == 'PF': + is_color = True + elif header == 'Pf': + is_color = False + else: + raise ValueError('The provided file is not a valid PFM file.') + + dimensions = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('UTF-8')) + if dimensions: + img_width, img_height = map(int, dimensions.groups()) + else: + raise ValueError('Invalid PFM header format.') + + endian_scale = float(file.readline().decode('UTF-8').strip()) + if endian_scale < 0: + dtype = '= img_size * 3/4, and max dimension will be >= img_size")) + return parser + + +def convert_ndc_to_pinhole(focal_length, principal_point, image_size): + focal_length = np.array(focal_length) + principal_point = np.array(principal_point) + image_size_wh = np.array([image_size[1], image_size[0]]) + half_image_size = image_size_wh / 2 + rescale = half_image_size.min() + principal_point_px = half_image_size - principal_point * rescale + focal_length_px = focal_length * rescale + fx, fy = focal_length_px[0], focal_length_px[1] + cx, cy = principal_point_px[0], principal_point_px[1] + K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32) + return K + + +def opencv_from_cameras_projection(R, T, focal, p0, image_size): + R = torch.from_numpy(R)[None, :, :] + T = torch.from_numpy(T)[None, :] + focal = torch.from_numpy(focal)[None, :] + p0 = torch.from_numpy(p0)[None, :] + image_size = torch.from_numpy(image_size)[None, :] + + R_pytorch3d = R.clone() + T_pytorch3d = T.clone() + focal_pytorch3d = focal + p0_pytorch3d = p0 + T_pytorch3d[:, :2] *= -1 + R_pytorch3d[:, :, :2] *= -1 + tvec = T_pytorch3d + R = R_pytorch3d.permute(0, 2, 1) + + # Retype the image_size correctly and flip to width, height. + image_size_wh = image_size.to(R).flip(dims=(1,)) + + # NDC to screen conversion. + scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0 + scale = scale.expand(-1, 2) + c0 = image_size_wh / 2.0 + + principal_point = -p0_pytorch3d * scale + c0 + focal_length = focal_pytorch3d * scale + + camera_matrix = torch.zeros_like(R) + camera_matrix[:, :2, 2] = principal_point + camera_matrix[:, 2, 2] = 1.0 + camera_matrix[:, 0, 0] = focal_length[:, 0] + camera_matrix[:, 1, 1] = focal_length[:, 1] + return R[0], tvec[0], camera_matrix[0] + + +def get_set_list(category_dir, split, is_single_sequence_subset=False): + listfiles = os.listdir(osp.join(category_dir, "set_lists")) + if is_single_sequence_subset: + # not all objects have manyview_dev + subset_list_files = [f for f in listfiles if "manyview_dev" in f] + else: + subset_list_files = [f for f in listfiles if f"fewview_train" in f] + + sequences_all = [] + for subset_list_file in subset_list_files: + with open(osp.join(category_dir, "set_lists", subset_list_file)) as f: + subset_lists_data = json.load(f) + sequences_all.extend(subset_lists_data[split]) + + return sequences_all + + +def prepare_sequences(category, co3d_dir, output_dir, img_size, split, min_quality, max_num_sequences_per_object, + seed, is_single_sequence_subset=False): + random.seed(seed) + category_dir = osp.join(co3d_dir, category) + category_output_dir = osp.join(output_dir, category) + sequences_all = get_set_list(category_dir, split, is_single_sequence_subset) + sequences_numbers = sorted(set(seq_name for seq_name, _, _ in sequences_all)) + + frame_file = osp.join(category_dir, "frame_annotations.jgz") + sequence_file = osp.join(category_dir, "sequence_annotations.jgz") + + with gzip.open(frame_file, "r") as fin: + frame_data = json.loads(fin.read()) + with gzip.open(sequence_file, "r") as fin: + sequence_data = json.loads(fin.read()) + + frame_data_processed = {} + for f_data in frame_data: + sequence_name = f_data["sequence_name"] + frame_data_processed.setdefault(sequence_name, {})[f_data["frame_number"]] = f_data + + good_quality_sequences = set() + for seq_data in sequence_data: + if seq_data["viewpoint_quality_score"] > min_quality: + good_quality_sequences.add(seq_data["sequence_name"]) + + sequences_numbers = [seq_name for seq_name in sequences_numbers if seq_name in good_quality_sequences] + if len(sequences_numbers) < max_num_sequences_per_object: + selected_sequences_numbers = sequences_numbers + else: + selected_sequences_numbers = random.sample(sequences_numbers, max_num_sequences_per_object) + + selected_sequences_numbers_dict = {seq_name: [] for seq_name in selected_sequences_numbers} + sequences_all = [(seq_name, frame_number, filepath) + for seq_name, frame_number, filepath in sequences_all + if seq_name in selected_sequences_numbers_dict] + + for seq_name, frame_number, filepath in tqdm(sequences_all): + frame_idx = int(filepath.split('/')[-1][5:-4]) + selected_sequences_numbers_dict[seq_name].append(frame_idx) + mask_path = filepath.replace("images", "masks").replace(".jpg", ".png") + frame_data = frame_data_processed[seq_name][frame_number] + focal_length = frame_data["viewpoint"]["focal_length"] + principal_point = frame_data["viewpoint"]["principal_point"] + image_size = frame_data["image"]["size"] + K = convert_ndc_to_pinhole(focal_length, principal_point, image_size) + R, tvec, camera_intrinsics = opencv_from_cameras_projection(np.array(frame_data["viewpoint"]["R"]), + np.array(frame_data["viewpoint"]["T"]), + np.array(focal_length), + np.array(principal_point), + np.array(image_size)) + + frame_data = frame_data_processed[seq_name][frame_number] + depth_path = os.path.join(co3d_dir, frame_data["depth"]["path"]) + assert frame_data["depth"]["scale_adjustment"] == 1.0 + image_path = os.path.join(co3d_dir, filepath) + mask_path_full = os.path.join(co3d_dir, mask_path) + + input_rgb_image = PIL.Image.open(image_path).convert('RGB') + input_mask = plt.imread(mask_path_full) + + with PIL.Image.open(depth_path) as depth_pil: + # the image is stored with 16-bit depth but PIL reads it as I (32 bit). + # we cast it to uint16, then reinterpret as float16, then cast to float32 + input_depthmap = ( + np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) + .astype(np.float32) + .reshape((depth_pil.size[1], depth_pil.size[0]))) + depth_mask = np.stack((input_depthmap, input_mask), axis=-1) + H, W = input_depthmap.shape + + camera_intrinsics = camera_intrinsics.numpy() + cx, cy = camera_intrinsics[:2, 2].round().astype(int) + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap( + input_rgb_image, depth_mask, camera_intrinsics, crop_bbox) + + # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384 + scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + if max(output_resolution) < img_size: + # let's put the max dimension to img_size + scale_final = (img_size / max(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap( + input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution) + input_depthmap = depth_mask[:, :, 0] + input_mask = depth_mask[:, :, 1] + + # generate and adjust camera pose + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = R + camera_pose[:3, 3] = tvec + camera_pose = np.linalg.inv(camera_pose) + + # save crop images and depth, metadata + save_img_path = os.path.join(output_dir, filepath) + save_depth_path = os.path.join(output_dir, frame_data["depth"]["path"]) + save_mask_path = os.path.join(output_dir, mask_path) + os.makedirs(os.path.split(save_img_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True) + + input_rgb_image.save(save_img_path) + scaled_depth_map = (input_depthmap / np.max(input_depthmap) * 65535).astype(np.uint16) + cv2.imwrite(save_depth_path, scaled_depth_map) + cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8)) + + save_meta_path = save_img_path.replace('jpg', 'npz') + np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics, + camera_pose=camera_pose, maximum_depth=np.max(input_depthmap)) + + return selected_sequences_numbers_dict + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + assert args.co3d_dir != args.output_dir + if args.category is None: + if args.single_sequence_subset: + categories = SINGLE_SEQUENCE_CATEGORIES + else: + categories = CATEGORIES + else: + categories = [args.category] + os.makedirs(args.output_dir, exist_ok=True) + + for split in ['train', 'test']: + selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json') + if os.path.isfile(selected_sequences_path): + continue + + all_selected_sequences = {} + for category in categories: + category_output_dir = osp.join(args.output_dir, category) + os.makedirs(category_output_dir, exist_ok=True) + category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json') + if os.path.isfile(category_selected_sequences_path): + with open(category_selected_sequences_path, 'r') as fid: + category_selected_sequences = json.load(fid) + else: + print(f"Processing {split} - category = {category}") + category_selected_sequences = prepare_sequences( + category=category, + co3d_dir=args.co3d_dir, + output_dir=args.output_dir, + img_size=args.img_size, + split=split, + min_quality=args.min_quality, + max_num_sequences_per_object=args.num_sequences_per_object, + seed=args.seed + CATEGORIES_IDX[category], + is_single_sequence_subset=args.single_sequence_subset + ) + with open(category_selected_sequences_path, 'w') as file: + json.dump(category_selected_sequences, file) + + all_selected_sequences[category] = category_selected_sequences + with open(selected_sequences_path, 'w') as file: + json.dump(all_selected_sequences, file) diff --git a/monst3r/datasets_preprocess/preprocess_megadepth.py b/monst3r/datasets_preprocess/preprocess_megadepth.py new file mode 100644 index 0000000..b07c0c5 --- /dev/null +++ b/monst3r/datasets_preprocess/preprocess_megadepth.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Preprocessing code for the MegaDepth dataset +# dataset at https://www.cs.cornell.edu/projects/megadepth/ +# -------------------------------------------------------- +import os +import os.path as osp +import collections +from tqdm import tqdm +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 +import h5py + +import path_to_root # noqa +from dust3r.utils.parallel import parallel_threads +from dust3r.datasets.utils import cropping # noqa + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--megadepth_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/megadepth_processed') + return parser + + +def main(db_root, pairs_path, output_dir): + os.makedirs(output_dir, exist_ok=True) + + # load all pairs + data = np.load(pairs_path, allow_pickle=True) + scenes = data['scenes'] + images = data['images'] + pairs = data['pairs'] + + # enumerate all unique images + todo = collections.defaultdict(set) + for scene, im1, im2, score in pairs: + todo[scene].add(im1) + todo[scene].add(im2) + + # for each scene, load intrinsics and then parallel crops + for scene, im_idxs in tqdm(todo.items(), desc='Overall'): + scene, subscene = scenes[scene].split() + out_dir = osp.join(output_dir, scene, subscene) + os.makedirs(out_dir, exist_ok=True) + + # load all camera params + _, pose_w2cam, intrinsics = _load_kpts_and_poses(db_root, scene, subscene, intrinsics=True) + + in_dir = osp.join(db_root, scene, 'dense' + subscene) + args = [(in_dir, img, intrinsics[img], pose_w2cam[img], out_dir) + for img in [images[im_id] for im_id in im_idxs]] + parallel_threads(resize_one_image, args, star_args=True, front_num=0, leave=False, desc=f'{scene}/{subscene}') + + # save pairs + print('Done! prepared all pairs in', output_dir) + + +def resize_one_image(root, tag, K_pre_rectif, pose_w2cam, out_dir): + if osp.isfile(osp.join(out_dir, tag + '.npz')): + return + + # load image + img = cv2.cvtColor(cv2.imread(osp.join(root, 'imgs', tag), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + H, W = img.shape[:2] + + # load depth + with h5py.File(osp.join(root, 'depths', osp.splitext(tag)[0] + '.h5'), 'r') as hd5: + depthmap = np.asarray(hd5['depth']) + + # rectify = undistort the intrinsics + imsize_pre, K_pre, distortion = K_pre_rectif + imsize_post = img.shape[1::-1] + K_post = cv2.getOptimalNewCameraMatrix(K_pre, distortion, imsize_pre, alpha=0, + newImgSize=imsize_post, centerPrincipalPoint=True)[0] + + # downscale + img_out, depthmap_out, intrinsics_out, R_in2out = _downscale_image(K_post, img, depthmap, resolution_out=(800, 600)) + + # write everything + img_out.save(osp.join(out_dir, tag + '.jpg'), quality=90) + cv2.imwrite(osp.join(out_dir, tag + '.exr'), depthmap_out) + + camout2world = np.linalg.inv(pose_w2cam) + camout2world[:3, :3] = camout2world[:3, :3] @ R_in2out.T + np.savez(osp.join(out_dir, tag + '.npz'), intrinsics=intrinsics_out, cam2world=camout2world) + + +def _downscale_image(camera_intrinsics, image, depthmap, resolution_out=(512, 384)): + H, W = image.shape[:2] + resolution_out = sorted(resolution_out)[::+1 if W < H else -1] + + image, depthmap, intrinsics_out = cropping.rescale_image_depthmap( + image, depthmap, camera_intrinsics, resolution_out, force=False) + R_in2out = np.eye(3) + + return image, depthmap, intrinsics_out, R_in2out + + +def _load_kpts_and_poses(root, scene_id, subscene, z_only=False, intrinsics=False): + if intrinsics: + with open(os.path.join(root, scene_id, 'sparse', 'manhattan', subscene, 'cameras.txt'), 'r') as f: + raw = f.readlines()[3:] # skip the header + + camera_intrinsics = {} + for camera in raw: + camera = camera.split(' ') + width, height, focal, cx, cy, k0 = [float(elem) for elem in camera[2:]] + K = np.eye(3) + K[0, 0] = focal + K[1, 1] = focal + K[0, 2] = cx + K[1, 2] = cy + camera_intrinsics[int(camera[0])] = ((int(width), int(height)), K, (k0, 0, 0, 0)) + + with open(os.path.join(root, scene_id, 'sparse', 'manhattan', subscene, 'images.txt'), 'r') as f: + raw = f.read().splitlines()[4:] # skip the header + + extract_pose = colmap_raw_pose_to_principal_axis if z_only else colmap_raw_pose_to_RT + + poses = {} + points3D_idxs = {} + camera = [] + + for image, points in zip(raw[:: 2], raw[1:: 2]): + image = image.split(' ') + points = points.split(' ') + + image_id = image[-1] + camera.append(int(image[-2])) + + # find the principal axis + raw_pose = [float(elem) for elem in image[1: -2]] + poses[image_id] = extract_pose(raw_pose) + + current_points3D_idxs = {int(i) for i in points[2:: 3] if i != '-1'} + assert -1 not in current_points3D_idxs, bb() + points3D_idxs[image_id] = current_points3D_idxs + + if intrinsics: + image_intrinsics = {im_id: camera_intrinsics[cam] for im_id, cam in zip(poses, camera)} + return points3D_idxs, poses, image_intrinsics + else: + return points3D_idxs, poses + + +def colmap_raw_pose_to_principal_axis(image_pose): + qvec = image_pose[: 4] + qvec = qvec / np.linalg.norm(qvec) + w, x, y, z = qvec + z_axis = np.float32([ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y + ]) + return z_axis + + +def colmap_raw_pose_to_RT(image_pose): + qvec = image_pose[: 4] + qvec = qvec / np.linalg.norm(qvec) + w, x, y, z = qvec + R = np.array([ + [ + 1 - 2 * y * y - 2 * z * z, + 2 * x * y - 2 * z * w, + 2 * x * z + 2 * y * w + ], + [ + 2 * x * y + 2 * z * w, + 1 - 2 * x * x - 2 * z * z, + 2 * y * z - 2 * x * w + ], + [ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y + ] + ]) + # principal_axis.append(R[2, :]) + t = image_pose[4: 7] + # World-to-Camera pose + current_pose = np.eye(4) + current_pose[: 3, : 3] = R + current_pose[: 3, 3] = t + return current_pose + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.megadepth_dir, args.precomputed_pairs, args.output_dir) diff --git a/monst3r/datasets_preprocess/preprocess_scannetpp.py b/monst3r/datasets_preprocess/preprocess_scannetpp.py new file mode 100644 index 0000000..34e26dc --- /dev/null +++ b/monst3r/datasets_preprocess/preprocess_scannetpp.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Script to pre-process the scannet++ dataset. +# Usage: +# python3 datasets_preprocess/preprocess_scannetpp.py --scannetpp_dir /path/to/scannetpp --precomputed_pairs /path/to/scannetpp_pairs --pyopengl-platform egl +# -------------------------------------------------------- +import os +import argparse +import os.path as osp +import re +from tqdm import tqdm +import json +from scipy.spatial.transform import Rotation +import pyrender +import trimesh +import trimesh.exchange.ply +import numpy as np +import cv2 +import PIL.Image as Image + +from dust3r.datasets.utils.cropping import rescale_image_depthmap +import dust3r.utils.geometry as geometry + +inv = np.linalg.inv +norm = np.linalg.norm +REGEXPR_DSLR = re.compile(r'^DSC(?P\d+).JPG$') +REGEXPR_IPHONE = re.compile(r'frame_(?P\d+).jpg$') + +DEBUG_VIZ = None # 'iou' +if DEBUG_VIZ is not None: + import matplotlib.pyplot as plt # noqa + + +OPENGL_TO_OPENCV = np.float32([[1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1]]) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('--scannetpp_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/scannetpp_processed') + parser.add_argument('--target_resolution', default=920, type=int, help="images resolution") + parser.add_argument('--pyopengl-platform', type=str, default='', help='PyOpenGL env variable') + return parser + + +def pose_from_qwxyz_txyz(elems): + qw, qx, qy, qz, tx, ty, tz = map(float, elems) + pose = np.eye(4) + pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix() + pose[:3, 3] = (tx, ty, tz) + return np.linalg.inv(pose) # returns cam2world + + +def get_frame_number(name, cam_type='dslr'): + if cam_type == 'dslr': + regex_expr = REGEXPR_DSLR + elif cam_type == 'iphone': + regex_expr = REGEXPR_IPHONE + else: + raise NotImplementedError(f'wrong {cam_type=} for get_frame_number') + matches = re.match(regex_expr, name) + return matches['frameid'] + + +def load_sfm(sfm_dir, cam_type='dslr'): + # load cameras + with open(osp.join(sfm_dir, 'cameras.txt'), 'r') as f: + raw = f.read().splitlines()[3:] # skip header + + intrinsics = {} + for camera in tqdm(raw, position=1, leave=False): + camera = camera.split(' ') + intrinsics[int(camera[0])] = [camera[1]] + [float(cam) for cam in camera[2:]] + + # load images + with open(os.path.join(sfm_dir, 'images.txt'), 'r') as f: + raw = f.read().splitlines() + raw = [line for line in raw if not line.startswith('#')] # skip header + + img_idx = {} + img_infos = {} + for image, points in tqdm(zip(raw[0::2], raw[1::2]), total=len(raw) // 2, position=1, leave=False): + image = image.split(' ') + points = points.split(' ') + + idx = image[0] + img_name = image[-1] + assert img_name not in img_idx, 'duplicate db image: ' + img_name + img_idx[img_name] = idx # register image name + + current_points2D = {int(i): (float(x), float(y)) + for i, x, y in zip(points[2::3], points[0::3], points[1::3]) if i != '-1'} + img_infos[idx] = dict(intrinsics=intrinsics[int(image[-2])], + path=img_name, + frame_id=get_frame_number(img_name, cam_type), + cam_to_world=pose_from_qwxyz_txyz(image[1: -2]), + sparse_pts2d=current_points2D) + + # load 3D points + with open(os.path.join(sfm_dir, 'points3D.txt'), 'r') as f: + raw = f.read().splitlines() + raw = [line for line in raw if not line.startswith('#')] # skip header + + points3D = {} + observations = {idx: [] for idx in img_infos.keys()} + for point in tqdm(raw, position=1, leave=False): + point = point.split() + point_3d_idx = int(point[0]) + points3D[point_3d_idx] = tuple(map(float, point[1:4])) + if len(point) > 8: + for idx, point_2d_idx in zip(point[8::2], point[9::2]): + observations[idx].append((point_3d_idx, int(point_2d_idx))) + + return img_idx, img_infos, points3D, observations + + +def subsample_img_infos(img_infos, num_images, allowed_name_subset=None): + img_infos_val = [(idx, val) for idx, val in img_infos.items()] + if allowed_name_subset is not None: + img_infos_val = [(idx, val) for idx, val in img_infos_val if val['path'] in allowed_name_subset] + + if len(img_infos_val) > num_images: + img_infos_val = sorted(img_infos_val, key=lambda x: x[1]['frame_id']) + kept_idx = np.round(np.linspace(0, len(img_infos_val) - 1, num_images)).astype(int).tolist() + img_infos_val = [img_infos_val[idx] for idx in kept_idx] + return {idx: val for idx, val in img_infos_val} + + +def undistort_images(intrinsics, rgb, mask): + camera_type = intrinsics[0] + + width = int(intrinsics[1]) + height = int(intrinsics[2]) + fx = intrinsics[3] + fy = intrinsics[4] + cx = intrinsics[5] + cy = intrinsics[6] + distortion = np.array(intrinsics[7:]) + + K = np.zeros([3, 3]) + K[0, 0] = fx + K[0, 2] = cx + K[1, 1] = fy + K[1, 2] = cy + K[2, 2] = 1 + + K = geometry.colmap_to_opencv_intrinsics(K) + if camera_type == "OPENCV_FISHEYE": + assert len(distortion) == 4 + + new_K = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify( + K, + distortion, + (width, height), + np.eye(3), + balance=0.0, + ) + # Make the cx and cy to be the center of the image + new_K[0, 2] = width / 2.0 + new_K[1, 2] = height / 2.0 + + map1, map2 = cv2.fisheye.initUndistortRectifyMap(K, distortion, np.eye(3), new_K, (width, height), cv2.CV_32FC1) + else: + new_K, _ = cv2.getOptimalNewCameraMatrix(K, distortion, (width, height), 1, (width, height), True) + map1, map2 = cv2.initUndistortRectifyMap(K, distortion, np.eye(3), new_K, (width, height), cv2.CV_32FC1) + + undistorted_image = cv2.remap(rgb, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101) + undistorted_mask = cv2.remap(mask, map1, map2, interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, borderValue=255) + K = geometry.opencv_to_colmap_intrinsics(K) + return width, height, new_K, undistorted_image, undistorted_mask + + +def process_scenes(root, pairsdir, output_dir, target_resolution): + os.makedirs(output_dir, exist_ok=True) + + # default values from + # https://github.com/scannetpp/scannetpp/blob/main/common/configs/render.yml + znear = 0.05 + zfar = 20.0 + + listfile = osp.join(pairsdir, 'scene_list.json') + with open(listfile, 'r') as f: + scenes = json.load(f) + + # for each of these, we will select some dslr images and some iphone images + # we will undistort them and render their depth + renderer = pyrender.OffscreenRenderer(0, 0) + for scene in tqdm(scenes, position=0, leave=True): + data_dir = os.path.join(root, 'data', scene) + dir_dslr = os.path.join(data_dir, 'dslr') + dir_iphone = os.path.join(data_dir, 'iphone') + dir_scans = os.path.join(data_dir, 'scans') + + assert os.path.isdir(data_dir) and os.path.isdir(dir_dslr) \ + and os.path.isdir(dir_iphone) and os.path.isdir(dir_scans) + + output_dir_scene = os.path.join(output_dir, scene) + scene_metadata_path = osp.join(output_dir_scene, 'scene_metadata.npz') + if osp.isfile(scene_metadata_path): + continue + + pairs_dir_scene = os.path.join(pairsdir, scene) + pairs_dir_scene_selected_pairs = os.path.join(pairs_dir_scene, 'selected_pairs.npz') + assert osp.isfile(pairs_dir_scene_selected_pairs) + selected_npz = np.load(pairs_dir_scene_selected_pairs) + selection, pairs = selected_npz['selection'], selected_npz['pairs'] + + # set up the output paths + output_dir_scene_rgb = os.path.join(output_dir_scene, 'images') + output_dir_scene_depth = os.path.join(output_dir_scene, 'depth') + os.makedirs(output_dir_scene_rgb, exist_ok=True) + os.makedirs(output_dir_scene_depth, exist_ok=True) + + ply_path = os.path.join(dir_scans, 'mesh_aligned_0.05.ply') + + sfm_dir_dslr = os.path.join(dir_dslr, 'colmap') + rgb_dir_dslr = os.path.join(dir_dslr, 'resized_images') + mask_dir_dslr = os.path.join(dir_dslr, 'resized_anon_masks') + + sfm_dir_iphone = os.path.join(dir_iphone, 'colmap') + rgb_dir_iphone = os.path.join(dir_iphone, 'rgb') + mask_dir_iphone = os.path.join(dir_iphone, 'rgb_masks') + + # load the mesh + with open(ply_path, 'rb') as f: + mesh_kwargs = trimesh.exchange.ply.load_ply(f) + mesh_scene = trimesh.Trimesh(**mesh_kwargs) + + # read colmap reconstruction, we will only use the intrinsics and pose here + img_idx_dslr, img_infos_dslr, points3D_dslr, observations_dslr = load_sfm(sfm_dir_dslr, cam_type='dslr') + dslr_paths = { + "in_colmap": sfm_dir_dslr, + "in_rgb": rgb_dir_dslr, + "in_mask": mask_dir_dslr, + } + + img_idx_iphone, img_infos_iphone, points3D_iphone, observations_iphone = load_sfm( + sfm_dir_iphone, cam_type='iphone') + iphone_paths = { + "in_colmap": sfm_dir_iphone, + "in_rgb": rgb_dir_iphone, + "in_mask": mask_dir_iphone, + } + + mesh = pyrender.Mesh.from_trimesh(mesh_scene, smooth=False) + pyrender_scene = pyrender.Scene() + pyrender_scene.add(mesh) + + selection_dslr = [imgname + '.JPG' for imgname in selection if imgname.startswith('DSC')] + selection_iphone = [imgname + '.jpg' for imgname in selection if imgname.startswith('frame_')] + + # resize the image to a more manageable size and render depth + for selection_cam, img_idx, img_infos, paths_data in [(selection_dslr, img_idx_dslr, img_infos_dslr, dslr_paths), + (selection_iphone, img_idx_iphone, img_infos_iphone, iphone_paths)]: + rgb_dir = paths_data['in_rgb'] + mask_dir = paths_data['in_mask'] + for imgname in tqdm(selection_cam, position=1, leave=False): + imgidx = img_idx[imgname] + img_infos_idx = img_infos[imgidx] + rgb = np.array(Image.open(os.path.join(rgb_dir, img_infos_idx['path']))) + mask = np.array(Image.open(os.path.join(mask_dir, img_infos_idx['path'][:-3] + 'png'))) + + _, _, K, rgb, mask = undistort_images(img_infos_idx['intrinsics'], rgb, mask) + + # rescale_image_depthmap assumes opencv intrinsics + intrinsics = geometry.colmap_to_opencv_intrinsics(K) + image, mask, intrinsics = rescale_image_depthmap( + rgb, mask, intrinsics, (target_resolution, target_resolution * 3.0 / 4)) + + W, H = image.size + intrinsics = geometry.opencv_to_colmap_intrinsics(intrinsics) + + # update inpace img_infos_idx + img_infos_idx['intrinsics'] = intrinsics + rgb_outpath = os.path.join(output_dir_scene_rgb, img_infos_idx['path'][:-3] + 'jpg') + image.save(rgb_outpath) + + depth_outpath = os.path.join(output_dir_scene_depth, img_infos_idx['path'][:-3] + 'png') + # render depth image + renderer.viewport_width, renderer.viewport_height = W, H + fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1, 1], intrinsics[0, 2], intrinsics[1, 2] + camera = pyrender.camera.IntrinsicsCamera(fx, fy, cx, cy, znear=znear, zfar=zfar) + camera_node = pyrender_scene.add(camera, pose=img_infos_idx['cam_to_world'] @ OPENGL_TO_OPENCV) + + depth = renderer.render(pyrender_scene, flags=pyrender.RenderFlags.DEPTH_ONLY) + pyrender_scene.remove_node(camera_node) # dont forget to remove camera + + depth = (depth * 1000).astype('uint16') + # invalidate depth from mask before saving + depth_mask = (mask < 255) + depth[depth_mask] = 0 + Image.fromarray(depth).save(depth_outpath) + + trajectories = [] + intrinsics = [] + for imgname in selection: + if imgname.startswith('DSC'): + imgidx = img_idx_dslr[imgname + '.JPG'] + img_infos_idx = img_infos_dslr[imgidx] + elif imgname.startswith('frame_'): + imgidx = img_idx_iphone[imgname + '.jpg'] + img_infos_idx = img_infos_iphone[imgidx] + else: + raise ValueError('invalid image name') + + intrinsics.append(img_infos_idx['intrinsics']) + trajectories.append(img_infos_idx['cam_to_world']) + + intrinsics = np.stack(intrinsics, axis=0) + trajectories = np.stack(trajectories, axis=0) + # save metadata for this scene + np.savez(scene_metadata_path, + trajectories=trajectories, + intrinsics=intrinsics, + images=selection, + pairs=pairs) + + del img_infos + del pyrender_scene + + # concat all scene_metadata.npz into a single file + scene_data = {} + for scene_subdir in scenes: + scene_metadata_path = osp.join(output_dir, scene_subdir, 'scene_metadata.npz') + with np.load(scene_metadata_path) as data: + trajectories = data['trajectories'] + intrinsics = data['intrinsics'] + images = data['images'] + pairs = data['pairs'] + scene_data[scene_subdir] = {'trajectories': trajectories, + 'intrinsics': intrinsics, + 'images': images, + 'pairs': pairs} + + offset = 0 + counts = [] + scenes = [] + sceneids = [] + images = [] + intrinsics = [] + trajectories = [] + pairs = [] + for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()): + num_imgs = data['images'].shape[0] + img_pairs = data['pairs'] + + scenes.append(scene_subdir) + sceneids.extend([scene_idx] * num_imgs) + + images.append(data['images']) + + intrinsics.append(data['intrinsics']) + trajectories.append(data['trajectories']) + + # offset pairs + img_pairs[:, 0:2] += offset + pairs.append(img_pairs) + counts.append(offset) + + offset += num_imgs + + images = np.concatenate(images, axis=0) + intrinsics = np.concatenate(intrinsics, axis=0) + trajectories = np.concatenate(trajectories, axis=0) + pairs = np.concatenate(pairs, axis=0) + np.savez(osp.join(output_dir, 'all_metadata.npz'), + counts=counts, + scenes=scenes, + sceneids=sceneids, + images=images, + intrinsics=intrinsics, + trajectories=trajectories, + pairs=pairs) + print('all done') + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + if args.pyopengl_platform.strip(): + os.environ['PYOPENGL_PLATFORM'] = args.pyopengl_platform + process_scenes(args.scannetpp_dir, args.precomputed_pairs, args.output_dir, args.target_resolution) diff --git a/monst3r/datasets_preprocess/preprocess_staticthings3d.py b/monst3r/datasets_preprocess/preprocess_staticthings3d.py new file mode 100644 index 0000000..ee3eec1 --- /dev/null +++ b/monst3r/datasets_preprocess/preprocess_staticthings3d.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Preprocessing code for the StaticThings3D dataset +# dataset at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d +# 1) Download StaticThings3D in /path/to/StaticThings3D/ +# with the script at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/scripts/download_staticthings3d.sh +# --> depths.tar.bz2 frames_finalpass.tar.bz2 poses.tar.bz2 frames_cleanpass.tar.bz2 intrinsics.tar.bz2 +# 2) unzip everything in the same /path/to/StaticThings3D/ directory +# 5) python datasets_preprocess/preprocess_staticthings3d.py --StaticThings3D_dir /path/to/tmp/StaticThings3D/ +# -------------------------------------------------------- +import os +import os.path as osp +import re +from tqdm import tqdm +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +import path_to_root # noqa +from dust3r.utils.parallel import parallel_threads +from dust3r.datasets.utils import cropping # noqa + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--StaticThings3D_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/staticthings3d_processed') + return parser + + +def main(db_root, pairs_path, output_dir): + all_scenes = _list_all_scenes(db_root) + + # crop images + args = [(db_root, osp.join(split, subsplit, seq), camera, f'{n:04d}', output_dir) + for split, subsplit, seq in all_scenes for camera in ['left', 'right'] for n in range(6, 16)] + parallel_threads(load_crop_and_save, args, star_args=True, front_num=1) + + # verify that all images are there + CAM = {b'l': 'left', b'r': 'right'} + pairs = np.load(pairs_path) + for scene, seq, cam1, im1, cam2, im2 in tqdm(pairs): + seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}') + for cam, idx in [(CAM[cam1], im1), (CAM[cam2], im2)]: + for ext in ['clean', 'final']: + impath = osp.join(output_dir, seq_path, cam, f"{idx:04n}_{ext}.jpg") + assert osp.isfile(impath), f'missing an image at {impath=}' + + print(f'>> Saved all data to {output_dir}!') + + +def load_crop_and_save(db_root, relpath_, camera, num, out_dir): + relpath = osp.join(relpath_, camera, num) + if osp.isfile(osp.join(out_dir, relpath + '.npz')): + return + os.makedirs(osp.join(out_dir, relpath_, camera), exist_ok=True) + + # load everything + intrinsics_in = readFloat(osp.join(db_root, 'intrinsics', relpath_, num + '.float3')) + cam2world = np.linalg.inv(readFloat(osp.join(db_root, 'poses', relpath + '.float3'))) + depthmap_in = readFloat(osp.join(db_root, 'depths', relpath + '.float3')) + img_clean = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_cleanpass', + relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + img_final = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_finalpass', + relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + + # do the crop + assert img_clean.shape[:2] == (540, 960) + assert img_final.shape[:2] == (540, 960) + (clean_out, final_out), depthmap, intrinsics_out, R_in2out = _crop_image( + intrinsics_in, (img_clean, img_final), depthmap_in, (512, 384)) + + # write everything + clean_out.save(osp.join(out_dir, relpath + '_clean.jpg'), quality=80) + final_out.save(osp.join(out_dir, relpath + '_final.jpg'), quality=80) + cv2.imwrite(osp.join(out_dir, relpath + '.exr'), depthmap) + + # New camera parameters + cam2world[:3, :3] = cam2world[:3, :3] @ R_in2out.T + np.savez(osp.join(out_dir, relpath + '.npz'), intrinsics=intrinsics_out, cam2world=cam2world) + + +def _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_out=(512, 512)): + image, depthmap, intrinsics_out = cropping.rescale_image_depthmap( + color_image_in, depthmap_in, intrinsics_in, resolution_out) + R_in2out = np.eye(3) + return image, depthmap, intrinsics_out, R_in2out + + +def _list_all_scenes(path): + print('>> Listing all scenes') + + res = [] + for split in ['TRAIN']: + for subsplit in 'ABC': + for seq in os.listdir(osp.join(path, 'intrinsics', split, subsplit)): + res.append((split, subsplit, seq)) + print(f' (found ({len(res)}) scenes)') + assert res, f'Did not find anything at {path=}' + return res + + +def readFloat(name): + with open(name, 'rb') as f: + if (f.readline().decode("utf-8")) != 'float\n': + raise Exception('float file %s did not contain keyword' % name) + + dim = int(f.readline()) + + dims = [] + count = 1 + for i in range(0, dim): + d = int(f.readline()) + dims.append(d) + count *= d + + dims = list(reversed(dims)) + data = np.fromfile(f, np.float32, count).reshape(dims) + return data # Hxw or CxHxW NxCxHxW + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.StaticThings3D_dir, args.precomputed_pairs, args.output_dir) diff --git a/monst3r/datasets_preprocess/preprocess_waymo.py b/monst3r/datasets_preprocess/preprocess_waymo.py new file mode 100644 index 0000000..49b01a6 --- /dev/null +++ b/monst3r/datasets_preprocess/preprocess_waymo.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Preprocessing code for the WayMo Open dataset +# dataset at https://github.com/waymo-research/waymo-open-dataset +# 1) Accept the license +# 2) download all training/*.tfrecord files from Perception Dataset, version 1.4.2 +# 3) put all .tfrecord files in '/path/to/waymo_dir' +# 4) install the waymo_open_dataset package with +# `python3 -m pip install gcsfs waymo-open-dataset-tf-2-12-0==1.6.4` +# 5) execute this script as `python preprocess_waymo.py --waymo_dir /path/to/waymo_dir` +# -------------------------------------------------------- +import sys +import os +import os.path as osp +import shutil +import json +from tqdm import tqdm +import PIL.Image +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +import tensorflow.compat.v1 as tf +tf.enable_eager_execution() + +import path_to_root # noqa +from dust3r.utils.geometry import geotrf, inv +from dust3r.utils.image import imread_cv2 +from dust3r.utils.parallel import parallel_processes as parallel_map +from dust3r.datasets.utils import cropping +from dust3r.viz import show_raw_pointcloud + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--waymo_dir', default='data/waymo/training') + parser.add_argument('--precomputed_pairs', default='data/waymo/waymo_pairs.npz') + parser.add_argument('--output_dir', default='data/waymo_processed') + parser.add_argument('--workers', type=int, default=4) + return parser + + +def main(waymo_root, pairs_path, output_dir, workers=1): + extract_frames(waymo_root, output_dir, workers=workers) + make_crops(output_dir, workers=args.workers) + + # make sure all pairs are there + with np.load(pairs_path) as data: + scenes = data['scenes'] + frames = data['frames'] + pairs = data['pairs'] # (array of (scene_id, img1_id, img2_id) + + for scene_id, im1_id, im2_id in pairs: + for im_id in (im1_id, im2_id): + path = osp.join(output_dir, scenes[scene_id], frames[im_id] + '.jpg') + assert osp.isfile(path), f'Missing a file at {path=}\nDid you download all .tfrecord files?' + + # shutil.rmtree(osp.join(output_dir, 'tmp')) + print('Done! all data generated at', output_dir) + + +def _list_sequences(db_root): + print('>> Looking for sequences in', db_root) + res = sorted(f for f in os.listdir(db_root) if f.endswith('.tfrecord')) + print(f' found {len(res)} sequences') + return res + + +def extract_frames(db_root, output_dir, workers=8): + sequences = _list_sequences(db_root) + output_dir = osp.join(output_dir, 'tmp') + print('>> outputing result to', output_dir) + args = [(db_root, output_dir, seq) for seq in sequences] + parallel_map(process_one_seq, args, star_args=True, workers=workers) + + +def process_one_seq(db_root, output_dir, seq): + out_dir = osp.join(output_dir, seq) + os.makedirs(out_dir, exist_ok=True) + calib_path = osp.join(out_dir, 'calib.json') + if osp.isfile(calib_path): + print(f'Skipping {seq} as it is already processed') + return + + try: + with tf.device('/CPU:0'): + calib, frames = extract_frames_one_seq(osp.join(db_root, seq)) + except RuntimeError: + print(f'/!\\ Error with sequence {seq} /!\\', file=sys.stderr) + return # nothing is saved + + for f, (frame_name, views) in enumerate(tqdm(frames, leave=False)): + for cam_idx, view in views.items(): + img = PIL.Image.fromarray(view.pop('img')) + img.save(osp.join(out_dir, f'{f:05d}_{cam_idx}.jpg')) + np.savez(osp.join(out_dir, f'{f:05d}_{cam_idx}.npz'), **view) + + with open(calib_path, 'w') as f: + json.dump(calib, f) + + +def extract_frames_one_seq(filename): + from waymo_open_dataset import dataset_pb2 as open_dataset + from waymo_open_dataset.utils import frame_utils + + print('>> Opening', filename) + dataset = tf.data.TFRecordDataset(filename, compression_type='') + + calib = None + frames = [] + + for data in tqdm(dataset, leave=False): + frame = open_dataset.Frame() + # frame.ParseFromString(bytearray(data.numpy())) + frame.ParseFromString(data.numpy()) + + content = frame_utils.parse_range_image_and_camera_projection(frame) + range_images, camera_projections, _, range_image_top_pose = content + + views = {} + frames.append((frame.context.name, views)) + + # once in a sequence, read camera calibration info + if calib is None: + calib = [] + for cam in frame.context.camera_calibrations: + calib.append((cam.name, + dict(width=cam.width, + height=cam.height, + intrinsics=list(cam.intrinsic), + extrinsics=list(cam.extrinsic.transform)))) + + # convert LIDAR to pointcloud + points, cp_points = frame_utils.convert_range_image_to_point_cloud( + frame, + range_images, + camera_projections, + range_image_top_pose) + + # 3d points in vehicle frame. + points_all = np.concatenate(points, axis=0) + cp_points_all = np.concatenate(cp_points, axis=0) + + # The distance between lidar points and vehicle frame origin. + cp_points_all_tensor = tf.constant(cp_points_all, dtype=tf.int32) + + for i, image in enumerate(frame.images): + # select relevant 3D points for this view + mask = tf.equal(cp_points_all_tensor[..., 0], image.name) + cp_points_msk_tensor = tf.cast(tf.gather_nd(cp_points_all_tensor, tf.where(mask)), dtype=tf.float32) + + pose = np.asarray(image.pose.transform).reshape(4, 4) + timestamp = image.pose_timestamp + + rgb = tf.image.decode_jpeg(image.image).numpy() + + pix = cp_points_msk_tensor[..., 1:3].numpy().round().astype(np.int16) + pts3d = points_all[mask.numpy()] + + views[image.name] = dict(img=rgb, pose=pose, pixels=pix, pts3d=pts3d, timestamp=timestamp) + + if not 'show full point cloud': + show_raw_pointcloud([v['pts3d'] for v in views.values()], [v['img'] for v in views.values()]) + + return calib, frames + + +def make_crops(output_dir, workers=16, **kw): + tmp_dir = osp.join(output_dir, 'tmp') + sequences = _list_sequences(tmp_dir) + args = [(tmp_dir, output_dir, seq) for seq in sequences] + parallel_map(crop_one_seq, args, star_args=True, workers=workers, front_num=0) + + +def crop_one_seq(input_dir, output_dir, seq, resolution=512): + seq_dir = osp.join(input_dir, seq) + out_dir = osp.join(output_dir, seq) + if osp.isfile(osp.join(out_dir, '00100_1.jpg')): + return + os.makedirs(out_dir, exist_ok=True) + + # load calibration file + try: + with open(osp.join(seq_dir, 'calib.json')) as f: + calib = json.load(f) + except IOError: + print(f'/!\\ Error: Missing calib.json in sequence {seq} /!\\', file=sys.stderr) + return + + axes_transformation = np.array([ + [0, -1, 0, 0], + [0, 0, -1, 0], + [1, 0, 0, 0], + [0, 0, 0, 1]]) + + cam_K = {} + cam_distortion = {} + cam_res = {} + cam_to_car = {} + for cam_idx, cam_info in calib: + cam_idx = str(cam_idx) + cam_res[cam_idx] = (W, H) = (cam_info['width'], cam_info['height']) + f1, f2, cx, cy, k1, k2, p1, p2, k3 = cam_info['intrinsics'] + cam_K[cam_idx] = np.asarray([(f1, 0, cx), (0, f2, cy), (0, 0, 1)]) + cam_distortion[cam_idx] = np.asarray([k1, k2, p1, p2, k3]) + cam_to_car[cam_idx] = np.asarray(cam_info['extrinsics']).reshape(4, 4) # cam-to-vehicle + + frames = sorted(f[:-3] for f in os.listdir(seq_dir) if f.endswith('.jpg')) + + # from dust3r.viz import SceneViz + # viz = SceneViz() + + for frame in tqdm(frames, leave=False): + cam_idx = frame[-2] # cam index + assert cam_idx in '12345', f'bad {cam_idx=} in {frame=}' + data = np.load(osp.join(seq_dir, frame + 'npz')) + car_to_world = data['pose'] + W, H = cam_res[cam_idx] + + # load depthmap + pos2d = data['pixels'].round().astype(np.uint16) + x, y = pos2d.T + pts3d = data['pts3d'] # already in the car frame + pts3d = geotrf(axes_transformation @ inv(cam_to_car[cam_idx]), pts3d) + # X=LEFT_RIGHT y=ALTITUDE z=DEPTH + + # load image + image = imread_cv2(osp.join(seq_dir, frame + 'jpg')) + + # downscale image + output_resolution = (resolution, 1) if W > H else (1, resolution) + image, _, intrinsics2 = cropping.rescale_image_depthmap(image, None, cam_K[cam_idx], output_resolution) + image.save(osp.join(out_dir, frame + 'jpg'), quality=80) + + # save as an EXR file? yes it's smaller (and easier to load) + W, H = image.size + depthmap = np.zeros((H, W), dtype=np.float32) + pos2d = geotrf(intrinsics2 @ inv(cam_K[cam_idx]), pos2d).round().astype(np.int16) + x, y = pos2d.T + depthmap[y.clip(min=0, max=H - 1), x.clip(min=0, max=W - 1)] = pts3d[:, 2] + cv2.imwrite(osp.join(out_dir, frame + 'exr'), depthmap) + + # save camera parametes + cam2world = car_to_world @ cam_to_car[cam_idx] @ inv(axes_transformation) + np.savez(osp.join(out_dir, frame + 'npz'), intrinsics=intrinsics2, + cam2world=cam2world, distortion=cam_distortion[cam_idx]) + + # viz.add_rgbd(np.asarray(image), depthmap, intrinsics2, cam2world) + # viz.show() + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.waymo_dir, args.precomputed_pairs, args.output_dir, workers=args.workers) diff --git a/monst3r/datasets_preprocess/preprocess_wildrgbd.py b/monst3r/datasets_preprocess/preprocess_wildrgbd.py new file mode 100644 index 0000000..ff3f0f7 --- /dev/null +++ b/monst3r/datasets_preprocess/preprocess_wildrgbd.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Script to pre-process the WildRGB-D dataset. +# Usage: +# python3 datasets_preprocess/preprocess_wildrgbd.py --wildrgbd_dir /path/to/wildrgbd +# -------------------------------------------------------- + +import argparse +import random +import json +import os +import os.path as osp + +import PIL.Image +import numpy as np +import cv2 + +from tqdm.auto import tqdm +import matplotlib.pyplot as plt + +import path_to_root # noqa +import dust3r.datasets.utils.cropping as cropping # noqa +from dust3r.utils.image import imread_cv2 + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--output_dir", type=str, default="data/wildrgbd_processed") + parser.add_argument("--wildrgbd_dir", type=str, required=True) + parser.add_argument("--train_num_sequences_per_object", type=int, default=50) + parser.add_argument("--test_num_sequences_per_object", type=int, default=10) + parser.add_argument("--num_frames", type=int, default=100) + parser.add_argument("--seed", type=int, default=42) + + parser.add_argument("--img_size", type=int, default=512, + help=("lower dimension will be >= img_size * 3/4, and max dimension will be >= img_size")) + return parser + + +def get_set_list(category_dir, split): + listfiles = ["camera_eval_list.json", "nvs_list.json"] + + sequences_all = {s: {k: set() for k in listfiles} for s in ['train', 'val']} + for listfile in listfiles: + with open(osp.join(category_dir, listfile)) as f: + subset_lists_data = json.load(f) + for s in ['train', 'val']: + sequences_all[s][listfile].update(subset_lists_data[s]) + train_intersection = set.intersection(*list(sequences_all['train'].values())) + if split == "train": + return train_intersection + else: + all_seqs = set.union(*list(sequences_all['train'].values()), *list(sequences_all['val'].values())) + return all_seqs.difference(train_intersection) + + +def prepare_sequences(category, wildrgbd_dir, output_dir, img_size, split, max_num_sequences_per_object, + output_num_frames, seed): + random.seed(seed) + category_dir = osp.join(wildrgbd_dir, category) + category_output_dir = osp.join(output_dir, category) + sequences_all = get_set_list(category_dir, split) + sequences_all = sorted(sequences_all) + + sequences_all_tmp = [] + for seq_name in sequences_all: + scene_dir = osp.join(wildrgbd_dir, category_dir, seq_name) + if not os.path.isdir(scene_dir): + print(f'{scene_dir} does not exist, skipped') + continue + sequences_all_tmp.append(seq_name) + sequences_all = sequences_all_tmp + if len(sequences_all) <= max_num_sequences_per_object: + selected_sequences = sequences_all + else: + selected_sequences = random.sample(sequences_all, max_num_sequences_per_object) + + selected_sequences_numbers_dict = {} + for seq_name in tqdm(selected_sequences, leave=False): + scene_dir = osp.join(category_dir, seq_name) + scene_output_dir = osp.join(category_output_dir, seq_name) + with open(osp.join(scene_dir, 'metadata'), 'r') as f: + metadata = json.load(f) + + K = np.array(metadata["K"]).reshape(3, 3).T + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + w, h = metadata["w"], metadata["h"] + + camera_intrinsics = np.array( + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + ) + camera_to_world_path = os.path.join(scene_dir, 'cam_poses.txt') + camera_to_world_content = np.genfromtxt(camera_to_world_path) + camera_to_world = camera_to_world_content[:, 1:].reshape(-1, 4, 4) + + frame_idx = camera_to_world_content[:, 0] + num_frames = frame_idx.shape[0] + assert num_frames >= output_num_frames + assert np.all(frame_idx == np.arange(num_frames)) + + # selected_sequences_numbers_dict[seq_name] = num_frames + + selected_frames = np.round(np.linspace(0, num_frames - 1, output_num_frames)).astype(int).tolist() + selected_sequences_numbers_dict[seq_name] = selected_frames + + for frame_id in tqdm(selected_frames): + depth_path = os.path.join(scene_dir, 'depth', f'{frame_id:0>5d}.png') + masks_path = os.path.join(scene_dir, 'masks', f'{frame_id:0>5d}.png') + rgb_path = os.path.join(scene_dir, 'rgb', f'{frame_id:0>5d}.png') + + input_rgb_image = PIL.Image.open(rgb_path).convert('RGB') + input_mask = plt.imread(masks_path) + input_depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float64) + depth_mask = np.stack((input_depthmap, input_mask), axis=-1) + H, W = input_depthmap.shape + + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = int(cx - min_margin_x), int(cy - min_margin_y) + r, b = int(cx + min_margin_x), int(cy + min_margin_y) + crop_bbox = (l, t, r, b) + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap( + input_rgb_image, depth_mask, camera_intrinsics, crop_bbox) + + # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384 + scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + if max(output_resolution) < img_size: + # let's put the max dimension to img_size + scale_final = (img_size / max(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap( + input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution) + input_depthmap = depth_mask[:, :, 0] + input_mask = depth_mask[:, :, 1] + + camera_pose = camera_to_world[frame_id] + + # save crop images and depth, metadata + save_img_path = os.path.join(scene_output_dir, 'rgb', f'{frame_id:0>5d}.jpg') + save_depth_path = os.path.join(scene_output_dir, 'depth', f'{frame_id:0>5d}.png') + save_mask_path = os.path.join(scene_output_dir, 'masks', f'{frame_id:0>5d}.png') + os.makedirs(os.path.split(save_img_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True) + + input_rgb_image.save(save_img_path) + cv2.imwrite(save_depth_path, input_depthmap.astype(np.uint16)) + cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8)) + + save_meta_path = os.path.join(scene_output_dir, 'metadata', f'{frame_id:0>5d}.npz') + os.makedirs(os.path.split(save_meta_path)[0], exist_ok=True) + np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics, + camera_pose=camera_pose) + + return selected_sequences_numbers_dict + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + assert args.wildrgbd_dir != args.output_dir + + categories = sorted([ + dirname for dirname in os.listdir(args.wildrgbd_dir) + if os.path.isdir(os.path.join(args.wildrgbd_dir, dirname, 'scenes')) + ]) + + os.makedirs(args.output_dir, exist_ok=True) + + splits_num_sequences_per_object = [args.train_num_sequences_per_object, args.test_num_sequences_per_object] + for split, num_sequences_per_object in zip(['train', 'test'], splits_num_sequences_per_object): + selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json') + if os.path.isfile(selected_sequences_path): + continue + all_selected_sequences = {} + for category in categories: + category_output_dir = osp.join(args.output_dir, category) + os.makedirs(category_output_dir, exist_ok=True) + category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json') + if os.path.isfile(category_selected_sequences_path): + with open(category_selected_sequences_path, 'r') as fid: + category_selected_sequences = json.load(fid) + else: + print(f"Processing {split} - category = {category}") + category_selected_sequences = prepare_sequences( + category=category, + wildrgbd_dir=args.wildrgbd_dir, + output_dir=args.output_dir, + img_size=args.img_size, + split=split, + max_num_sequences_per_object=num_sequences_per_object, + output_num_frames=args.num_frames, + seed=args.seed + int("category".encode('ascii').hex(), 16), + ) + with open(category_selected_sequences_path, 'w') as file: + json.dump(category_selected_sequences, file) + + all_selected_sequences[category] = category_selected_sequences + with open(selected_sequences_path, 'w') as file: + json.dump(all_selected_sequences, file) diff --git a/monst3r/datasets_preprocess/scannet_sens_reader.py b/monst3r/datasets_preprocess/scannet_sens_reader.py new file mode 100644 index 0000000..4301d23 --- /dev/null +++ b/monst3r/datasets_preprocess/scannet_sens_reader.py @@ -0,0 +1,162 @@ +import argparse +import os, sys + +import os, struct +import numpy as np +import zlib +import imageio +import cv2 +import png + +COMPRESSION_TYPE_COLOR = {-1:'unknown', 0:'raw', 1:'png', 2:'jpeg'} +COMPRESSION_TYPE_DEPTH = {-1:'unknown', 0:'raw_ushort', 1:'zlib_ushort', 2:'occi_ushort'} + +class RGBDFrame(): + + def load(self, file_handle): + self.camera_to_world = np.asarray(struct.unpack('f'*16, file_handle.read(16*4)), dtype=np.float32).reshape(4, 4) + self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0] + self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0] + self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0] + self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0] + self.color_data = b''.join(struct.unpack('c'*self.color_size_bytes, file_handle.read(self.color_size_bytes))) + self.depth_data = b''.join(struct.unpack('c'*self.depth_size_bytes, file_handle.read(self.depth_size_bytes))) + + + def decompress_depth(self, compression_type): + if compression_type == 'zlib_ushort': + return self.decompress_depth_zlib() + else: + raise + + + def decompress_depth_zlib(self): + return zlib.decompress(self.depth_data) + + + def decompress_color(self, compression_type): + if compression_type == 'jpeg': + return self.decompress_color_jpeg() + else: + raise + + + def decompress_color_jpeg(self): + return imageio.imread(self.color_data) + + +class SensorData: + + def __init__(self, filename): + self.version = 4 + self.load(filename) + + + def load(self, filename): + with open(filename, 'rb') as f: + version = struct.unpack('I', f.read(4))[0] + assert self.version == version + strlen = struct.unpack('Q', f.read(8))[0] + self.sensor_name = b''.join(struct.unpack('c'*strlen, f.read(strlen))) + self.intrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) + self.extrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) + self.intrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) + self.extrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) + self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack('i', f.read(4))[0]] + self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack('i', f.read(4))[0]] + self.color_width = struct.unpack('I', f.read(4))[0] + self.color_height = struct.unpack('I', f.read(4))[0] + self.depth_width = struct.unpack('I', f.read(4))[0] + self.depth_height = struct.unpack('I', f.read(4))[0] + self.depth_shift = struct.unpack('f', f.read(4))[0] + num_frames = struct.unpack('Q', f.read(8))[0] + self.frames = [] + for i in range(num_frames): + frame = RGBDFrame() + frame.load(f) + self.frames.append(frame) + + + def export_depth_images(self, output_path, image_size=None, frame_skip=1): + if not os.path.exists(output_path): + os.makedirs(output_path) + print('exporting', len(self.frames)//frame_skip, ' depth frames to', output_path) + for f in range(0, len(self.frames), frame_skip): + depth_data = self.frames[f].decompress_depth(self.depth_compression_type) + depth = np.fromstring(depth_data, dtype=np.uint16).reshape(self.depth_height, self.depth_width) + if image_size is not None: + depth = cv2.resize(depth, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST) + #imageio.imwrite(os.path.join(output_path, str(f) + '.png'), depth) + with open(os.path.join(output_path, str(f) + '.png'), 'wb') as f: # write 16-bit + writer = png.Writer(width=depth.shape[1], height=depth.shape[0], bitdepth=16) + depth = depth.reshape(-1, depth.shape[1]).tolist() + writer.write(f, depth) + + def export_color_images(self, output_path, image_size=None, frame_skip=1): + if not os.path.exists(output_path): + os.makedirs(output_path) + print('exporting', len(self.frames)//frame_skip, 'color frames to', output_path) + for f in range(0, len(self.frames), frame_skip): + color = self.frames[f].decompress_color(self.color_compression_type) + if image_size is not None: + color = cv2.resize(color, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST) + imageio.imwrite(os.path.join(output_path, str(f) + '.jpg'), color) + + + def save_mat_to_file(self, matrix, filename): + with open(filename, 'w') as f: + for line in matrix: + np.savetxt(f, line[np.newaxis], fmt='%f') + + + def export_poses(self, output_path, frame_skip=1): + if not os.path.exists(output_path): + os.makedirs(output_path) + print('exporting', len(self.frames)//frame_skip, 'camera poses to', output_path) + for f in range(0, len(self.frames), frame_skip): + self.save_mat_to_file(self.frames[f].camera_to_world, os.path.join(output_path, str(f) + '.txt')) + + + def export_intrinsics(self, output_path): + if not os.path.exists(output_path): + os.makedirs(output_path) + print('exporting camera intrinsics to', output_path) + self.save_mat_to_file(self.intrinsic_color, os.path.join(output_path, 'intrinsic_color.txt')) + self.save_mat_to_file(self.extrinsic_color, os.path.join(output_path, 'extrinsic_color.txt')) + self.save_mat_to_file(self.intrinsic_depth, os.path.join(output_path, 'intrinsic_depth.txt')) + self.save_mat_to_file(self.extrinsic_depth, os.path.join(output_path, 'extrinsic_depth.txt')) + +# params +parser = argparse.ArgumentParser() +# data paths +parser.add_argument('--filename', required=True, help='path to sens file to read') +parser.add_argument('--output_path', required=True, help='path to output folder') +parser.add_argument('--export_depth_images', dest='export_depth_images', action='store_true') +parser.add_argument('--export_color_images', dest='export_color_images', action='store_true') +parser.add_argument('--export_poses', dest='export_poses', action='store_true') +parser.add_argument('--export_intrinsics', dest='export_intrinsics', action='store_true') +parser.set_defaults(export_depth_images=True, export_color_images=True, export_poses=True, export_intrinsics=True) + +opt = parser.parse_args() +print(opt) + + +def main(): + if not os.path.exists(opt.output_path): + os.makedirs(opt.output_path) + # load the data + sys.stdout.write('loading %s...' % opt.filename) + sd = SensorData(opt.filename) + sys.stdout.write('loaded!\n') + if opt.export_depth_images: + sd.export_depth_images(os.path.join(opt.output_path, 'depth')) + if opt.export_color_images: + sd.export_color_images(os.path.join(opt.output_path, 'color')) + if opt.export_poses: + sd.export_poses(os.path.join(opt.output_path, 'pose')) + if opt.export_intrinsics: + sd.export_intrinsics(os.path.join(opt.output_path, 'intrinsic')) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/monst3r/datasets_preprocess/sintel_get_dynamics.py b/monst3r/datasets_preprocess/sintel_get_dynamics.py new file mode 100644 index 0000000..e116380 --- /dev/null +++ b/monst3r/datasets_preprocess/sintel_get_dynamics.py @@ -0,0 +1,170 @@ +import numpy as np +import cv2 +import os +from tqdm import tqdm +import argparse +import torch + +TAG_FLOAT = 202021.25 +def flow_read(filename): + """ Read optical flow from file, return (U,V) tuple. + + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + f = open(filename,'rb') + check = np.fromfile(f,dtype=np.float32,count=1)[0] + assert check == TAG_FLOAT, 'flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine?'.format(TAG_FLOAT,check) + width = np.fromfile(f,dtype=np.int32,count=1)[0] + height = np.fromfile(f,dtype=np.int32,count=1)[0] + size = width*height + assert width > 0 and height > 0 and size > 1 and size < 100000000, 'flow_read:: Invalid input size (width = {0}, height = {1}).'.format(width,height) + tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2)) + u = tmp[:,np.arange(width)*2] + v = tmp[:,np.arange(width)*2 + 1] + return u,v + +def cam_read(filename): + """ Read camera data, return (M,N) tuple. + + M is the intrinsic matrix, N is the extrinsic matrix, so that + + x = M*N*X, + where x is a point in homogeneous image pixel coordinates, and X is a + point in homogeneous world coordinates. + """ + f = open(filename,'rb') + check = np.fromfile(f,dtype=np.float32,count=1)[0] + assert check == TAG_FLOAT, 'cam_read:: Wrong tag in cam file (should be: {0}, is: {1}). Big-endian machine?'.format(TAG_FLOAT,check) + M = np.fromfile(f,dtype='float64',count=9).reshape((3,3)) + N = np.fromfile(f,dtype='float64',count=12).reshape((3,4)) + return M,N + +def depth_read(filename): + """ Read depth data from file, return as numpy array. """ + f = open(filename,'rb') + check = np.fromfile(f,dtype=np.float32,count=1)[0] + assert check == TAG_FLOAT, 'depth_read:: Wrong tag in depth file (should be: {0}, is: {1}). Big-endian machine?'.format(TAG_FLOAT,check) + width = np.fromfile(f,dtype=np.int32,count=1)[0] + height = np.fromfile(f,dtype=np.int32,count=1)[0] + size = width*height + assert width > 0 and height > 0 and size > 1 and size < 100000000, 'depth_read:: Invalid input size (width = {0}, height = {1}).'.format(width,height) + depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width)) + return depth + +def RT_to_extrinsic_matrix(R, T): + extrinsic_matrix = np.concatenate([R, T], axis=-1) + extrinsic_matrix = np.concatenate([extrinsic_matrix, np.array([[0, 0, 0, 1]])], axis=0) + return np.linalg.inv(extrinsic_matrix) + +def depth_to_3d(depth_map, intrinsic_matrix): + height, width = depth_map.shape + i, j = np.meshgrid(np.arange(width), np.arange(height)) + + # Convert pixel coordinates and depth values to 3D points + x = (i - intrinsic_matrix[0, 2]) * depth_map / intrinsic_matrix[0, 0] + y = (j - intrinsic_matrix[1, 2]) * depth_map / intrinsic_matrix[1, 1] + z = depth_map + + points_3d = np.stack([x, y, z], axis=-1) + return points_3d + +def project_3d_to_2d(points_3d, intrinsic_matrix): + # Convert 3D points to homogeneous coordinates + projected_2d_hom = intrinsic_matrix @ points_3d.T + # Convert from homogeneous coordinates to 2D image coordinates + projected_2d = projected_2d_hom[:2, :] / projected_2d_hom[2, :] + return projected_2d.T + +def compute_optical_flow(depth1, depth2, pose1, pose2, intrinsic_matrix1, intrinsic_matrix2): + # Input: All inputs as numpy arrays; convert torch tensors to numpy arrays if needed + if isinstance(depth1, torch.Tensor): + depth1 = depth1.cpu().numpy() + if isinstance(depth2, torch.Tensor): + depth2 = depth2.cpu().numpy() + if isinstance(pose1, torch.Tensor): + pose1 = pose1.cpu().numpy() + if isinstance(pose2, torch.Tensor): + pose2 = pose2.cpu().numpy() + if isinstance(intrinsic_matrix1, torch.Tensor): + intrinsic_matrix1 = intrinsic_matrix1.cpu().numpy() + if isinstance(intrinsic_matrix2, torch.Tensor): + intrinsic_matrix2 = intrinsic_matrix2.cpu().numpy() + + points_3d_frame1 = depth_to_3d(depth1, intrinsic_matrix1).reshape(-1, 3) + points_3d_frame1_hom = np.concatenate([points_3d_frame1, np.ones((points_3d_frame1.shape[0], 1))], axis=1).T + + # Calculate the transformation matrix from frame 1 to frame 2 + transformation_matrix = (pose2) @ np.linalg.inv(pose1) + points_3d_frame2_hom = transformation_matrix @ points_3d_frame1_hom + points_3d_frame2 = (points_3d_frame2_hom[:3, :]).T + + points_2d_frame1 = project_3d_to_2d(points_3d_frame1, intrinsic_matrix1) + points_2d_frame2 = project_3d_to_2d(points_3d_frame2, intrinsic_matrix2) + + # Compute optical flow vectors + optical_flow = points_2d_frame2 - points_2d_frame1 + return optical_flow + +def get_dynamic_label(base_dir, seq, continuous=False, threshold=13.75, save_dir='dynamic_label'): + depth_dir = os.path.join(base_dir, 'depth', seq) + cam_dir = os.path.join(base_dir, 'camdata_left', seq) + flow_dir = os.path.join(base_dir, 'flow', seq) + dynamic_label_dir = os.path.join(base_dir, save_dir, seq) + os.makedirs(dynamic_label_dir, exist_ok=True) + + frames = sorted([f for f in os.listdir(depth_dir) if f.endswith('.dpt')]) + for i, frame1 in enumerate(frames): + if i == len(frames) - 1: + continue + frame2 = frames[i + 1] + + frame1_id = frame1.split('.')[0] + frame2_id = frame2.split('.')[0] + + # Load depth maps + depth_map_frame1 = depth_read(os.path.join(depth_dir, frame1)) + depth_map_frame2 = depth_read(os.path.join(depth_dir, frame2)) + + # Load camera intrinsics and poses + intrinsic_matrix1, pose_frame1 = cam_read(os.path.join(cam_dir, f'{frame1_id}.cam')) + intrinsic_matrix2, pose_frame2 = cam_read(os.path.join(cam_dir, f'{frame2_id}.cam')) + + # Pad pose with [0,0,0,1] + pose_frame1 = np.concatenate([pose_frame1, np.array([[0, 0, 0, 1]])], axis=0) + pose_frame2 = np.concatenate([pose_frame2, np.array([[0, 0, 0, 1]])], axis=0) + + # Compute optical flow + optical_flow = compute_optical_flow(depth_map_frame1, depth_map_frame2, pose_frame1, pose_frame2, intrinsic_matrix1, intrinsic_matrix2) + + # Reshape the optical flow to the image dimensions + height, width = depth_map_frame1.shape + optical_flow_image = optical_flow.reshape(height, width, 2) + + # Load ground truth optical flow + u, v = flow_read(os.path.join(flow_dir, f'{frame1_id}.flo')) + gt_flow = np.stack([u, v], axis=-1) + + # Compute the error map + error_map = np.linalg.norm(gt_flow - optical_flow_image, axis=-1) + if not continuous: + binary_error_map = error_map > threshold + + # Save the binary error map + cv2.imwrite(os.path.join(dynamic_label_dir, f'{frame1_id}.png'), binary_error_map.astype(np.uint8) * 255) + else: + # Normalize the error map + error_map = error_map / error_map.max() + cv2.imwrite(os.path.join(dynamic_label_dir, f'{frame1_id}.png'), (error_map * 255).astype(np.uint8)) + +if __name__ == '__main__': + # Process all sequences + sequences = sorted(os.listdir('data/sintel/training/depth')) + base_dir = 'data/sintel/training' + parser = argparse.ArgumentParser() + parser.add_argument('--continuous', action='store_true') + parser.add_argument('--threshold', type=float, default=13.75) + parser.add_argument('--save_dir', type=str, default='dynamic_label') + args = parser.parse_args() + for seq in tqdm(sequences): + get_dynamic_label(base_dir, seq, continuous=args.continuous, threshold=args.threshold, save_dir=args.save_dir) + print(f'Finished processing {seq}') diff --git a/monst3r/datasets_preprocess/waymo_make_pairs.py b/monst3r/datasets_preprocess/waymo_make_pairs.py new file mode 100644 index 0000000..6170dfb --- /dev/null +++ b/monst3r/datasets_preprocess/waymo_make_pairs.py @@ -0,0 +1,57 @@ +import glob +import os +from tqdm import tqdm +import cv2 +import numpy as np + +import numpy as np + +file_path = "data/waymo/waymo_pairs.npz" + +data = np.load(file_path) +# data.files # ['scenes', 'frames', 'pairs'] + +scenes, frames, pairs = data['scenes'], data['frames'], data['pairs'] + +new_scenes = glob.glob("data/waymo_processed/*.tfrecord/") +new_scenes_last = [scene.split("/")[-2] for scene in new_scenes] +img_lens = [] +for path in tqdm(new_scenes): + imgs = glob.glob(path + "/*.jpg") + img_lens.append(len(imgs)) + +new_frames = list(frames) +new_pairs = [] +strides = [1,2,3,4,5,6,7,8,9] +step = 1 +for path in tqdm(new_scenes): + imgs_track1 = glob.glob(path + "/*_1.jpg") + imgs_track1.sort() + imgs_track2 = glob.glob(path + "/*_2.jpg") + imgs_track2.sort() + imgs_track3 = glob.glob(path + "/*_3.jpg") + imgs_track3.sort() + imgs_track4 = glob.glob(path + "/*_4.jpg") + imgs_track4.sort() + imgs_track5 = glob.glob(path + "/*_5.jpg") + imgs_track5.sort() + for stride in strides: + for i in range(0, len(imgs_track1)-stride, step): + if os.path.exists(imgs_track1[i+stride]) and os.path.exists(imgs_track1[i]): + new_pairs.append([new_scenes_last.index(path.split("/")[-2]), new_frames.index(imgs_track1[i].split('/')[-1].replace('.jpg','')), new_frames.index(imgs_track1[i+stride].split('/')[-1].replace('.jpg',''))]) + for i in range(0, len(imgs_track2)-stride, step): + if os.path.exists(imgs_track2[i+stride]) and os.path.exists(imgs_track2[i]): + new_pairs.append([new_scenes_last.index(path.split("/")[-2]), new_frames.index(imgs_track2[i].split('/')[-1].replace('.jpg','')), new_frames.index(imgs_track2[i+stride].split('/')[-1].replace('.jpg',''))]) + for i in range(0, len(imgs_track3)-stride, step): + if os.path.exists(imgs_track3[i+stride]) and os.path.exists(imgs_track3[i]): + new_pairs.append([new_scenes_last.index(path.split("/")[-2]), new_frames.index(imgs_track3[i].split('/')[-1].replace('.jpg','')), new_frames.index(imgs_track3[i+stride].split('/')[-1].replace('.jpg',''))]) + for i in range(0, len(imgs_track4)-stride, step): + if os.path.exists(imgs_track4[i+stride]) and os.path.exists(imgs_track4[i]): + new_pairs.append([new_scenes_last.index(path.split("/")[-2]), new_frames.index(imgs_track4[i].split('/')[-1].replace('.jpg','')), new_frames.index(imgs_track4[i+stride].split('/')[-1].replace('.jpg',''))]) + for i in range(0, len(imgs_track5)-stride, step): + if os.path.exists(imgs_track5[i+stride]) and os.path.exists(imgs_track5[i]): + new_pairs.append([new_scenes_last.index(path.split("/")[-2]), new_frames.index(imgs_track5[i].split('/')[-1].replace('.jpg','')), new_frames.index(imgs_track5[i+stride].split('/')[-1].replace('.jpg',''))]) + +print(len(new_pairs), "pairs") +save_path = "data/waymo_processed/waymo_pairs_video.npz" +np.savez(save_path, scenes=np.array(new_scenes_last), frames=np.array(new_frames), pairs=np.array(new_pairs)) \ No newline at end of file diff --git a/monst3r/demo.py b/monst3r/demo.py new file mode 100644 index 0000000..70b1bf7 --- /dev/null +++ b/monst3r/demo.py @@ -0,0 +1,450 @@ +# -------------------------------------------------------- +# gradio demo +# -------------------------------------------------------- + +import argparse +import math +import gradio +import os +import torch +import numpy as np +import tempfile +import functools +import copy + +from dust3r.inference import inference +from dust3r.model import AsymmetricCroCo3DStereo +from dust3r.image_pairs import make_pairs +from dust3r.utils.image import load_images, rgb, enlarge_seg_masks +from dust3r.utils.device import to_numpy +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode +from dust3r.utils.viz_demo import convert_scene_output_to_glb, get_dynamic_mask_from_pairviewer +import matplotlib.pyplot as pl +import cv2 + + +# DEBUG +from dust3r.utils.viz_demo import render_colored_pointclouds, render_colored_pointclouds_pytorch3d +# DEBUG + +pl.ion() +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser_url = parser.add_mutually_exclusive_group() + parser_url.add_argument("--local_network", action='store_true', default=False, + help="make app accessible on local network: address will be set to 0.0.0.0") + parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1") + parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size") + parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). " + "If None, will search for an available port starting at 7860."), + default=None) + parser.add_argument("--weights", type=str, help="path to the model weights", default='checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth') + parser.add_argument("--model_name", type=str, default='Junyi42/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt', help="model name") + parser.add_argument("--device", type=str, default='cuda', help="pytorch device") + parser.add_argument("--output_dir", type=str, default='./demo_tmp', help="value for tempfile.tempdir") + parser.add_argument("--silent", action='store_true', default=False, + help="silence logs") + parser.add_argument("--input_dir", type=str, help="Path to input images directory", default=None) + parser.add_argument("--seq_name", type=str, help="Sequence name for evaluation", default='NULL') + parser.add_argument('--use_gt_davis_masks', action='store_true', default=False, help='Use ground truth masks for DAVIS') + parser.add_argument('--not_batchify', action='store_true', default=False, help='Use non batchify mode for global optimization') + parser.add_argument('--real_time', action='store_true', default=False, help='Realtime mode') + parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference') + + parser.add_argument('--fps', type=int, default=0, help='FPS for video processing') + parser.add_argument('--num_frames', type=int, default=200, help='Maximum number of frames for video processing') + + # Add "share" argument if you want to make the demo accessible on the public internet + parser.add_argument("--share", action='store_true', default=False, help="Share the demo") + return parser + +def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False, + clean_depth=False, transparent_cams=False, cam_size=0.05, show_cam=True, save_name=None, thr_for_init_conf=True): + """ + extract 3D_model (glb file) from a reconstructed scene + """ + if scene is None: + return None + # post processes + if clean_depth: + scene = scene.clean_pointcloud() + if mask_sky: + scene = scene.mask_sky() + + # get optimized values from scene + rgbimg = scene.imgs + focals = scene.get_focals().cpu() + cams2world = scene.get_im_poses().cpu() + # 3D pointcloud from depthmap, poses and intrinsics + pts3d = to_numpy(scene.get_pts3d(raw_pts=True)) + + # DEBUG: save the pts3d frame by frame, later visualize with cloudcompare + # import trimesh + # for i in range(len(pts3d)): + # pc = trimesh.PointCloud(pts3d[i].reshape(-1, 3)) + # pc.export(f'{outdir}/pts3d_frame_{i}.ply') + + scene.min_conf_thr = min_conf_thr + scene.thr_for_init_conf = thr_for_init_conf + msk = to_numpy(scene.get_masks()) + cmap = pl.get_cmap('viridis') + cam_color = [cmap(i/len(rgbimg))[:3] for i in range(len(rgbimg))] + cam_color = [(255*c[0], 255*c[1], 255*c[2]) for c in cam_color] + + # DEBUG, render images + pp = scene.get_principal_points().cpu() + import pdb; pdb.set_trace() + render_colored_pointclouds(outdir, rgbimg[0], pts3d[0], msk[0], focals, pp, cams2world[0:1]) + render_colored_pointclouds_pytorch3d(outdir, rgbimg[0], pts3d[0], msk[0], focals[0], pp[0], cams2world[0:1]) + + return convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud, + transparent_cams=transparent_cams, cam_size=cam_size, show_cam=show_cam, silent=silent, save_name=save_name, + cam_color=cam_color) + + +def get_reconstructed_scene(args, outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr, + as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, show_cam, scenegraph_type, winsize, refid, + seq_name, new_model_weights, temporal_smoothing_weight, translation_weight, shared_focal, + flow_loss_weight, flow_loss_start_iter, flow_loss_threshold, use_gt_mask, fps, num_frames): + """ + from a list of images, run dust3r inference, global aligner. + then run get_3D_model_from_scene + """ + translation_weight = float(translation_weight) + if new_model_weights != args.weights: + model = AsymmetricCroCo3DStereo.from_pretrained(new_model_weights).to(device) + model.eval() + if seq_name != "NULL": + dynamic_mask_path = f'data/davis/DAVIS/masked_images/480p/{seq_name}' + else: + dynamic_mask_path = None + imgs = load_images(filelist, size=image_size, verbose=not silent, dynamic_mask_root=dynamic_mask_path, fps=fps, num_frames=num_frames) + if len(imgs) == 1: + imgs = [imgs[0], copy.deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + if scenegraph_type == "swin" or scenegraph_type == "swinstride" or scenegraph_type == "swin2stride": + scenegraph_type = scenegraph_type + "-" + str(winsize) + "-noncyclic" + elif scenegraph_type == "oneref": + scenegraph_type = scenegraph_type + "-" + str(refid) + + pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True) + output = inference(pairs, model, device, batch_size=args.batch_size, verbose=not silent) + if len(imgs) > 2: + mode = GlobalAlignerMode.PointCloudOptimizer + scene = global_aligner(output, device=device, mode=mode, verbose=not silent, shared_focal = shared_focal, temporal_smoothing_weight=temporal_smoothing_weight, translation_weight=translation_weight, + flow_loss_weight=flow_loss_weight, flow_loss_start_epoch=flow_loss_start_iter, flow_loss_thre=flow_loss_threshold, use_self_mask=not use_gt_mask, + num_total_iter=niter, empty_cache= len(filelist) > 72, batchify=not args.not_batchify) + else: + mode = GlobalAlignerMode.PairViewer + scene = global_aligner(output, device=device, mode=mode, verbose=not silent) + lr = 0.01 + + if mode == GlobalAlignerMode.PointCloudOptimizer: + loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr) + + save_folder = f'{args.output_dir}/{seq_name}' #default is 'demo_tmp/NULL' + os.makedirs(save_folder, exist_ok=True) + outfile = get_3D_model_from_scene(save_folder, silent, scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam) + + poses = scene.save_tum_poses(f'{save_folder}/pred_traj.txt') + K = scene.save_intrinsics(f'{save_folder}/pred_intrinsics.txt') + depth_maps = scene.save_depth_maps(save_folder) + dynamic_masks = scene.save_dynamic_masks(save_folder) + conf = scene.save_conf_maps(save_folder) + init_conf = scene.save_init_conf_maps(save_folder) + rgbs = scene.save_rgb_imgs(save_folder) + enlarge_seg_masks(save_folder, kernel_size=5 if use_gt_mask else 3) + + # also return rgb, depth and confidence imgs + # depth is normalized with the max value for all images + # we apply the jet colormap on the confidence maps + rgbimg = scene.imgs + depths = to_numpy(scene.get_depthmaps()) + confs = to_numpy([c for c in scene.im_conf]) + init_confs = to_numpy([c for c in scene.init_conf_maps]) + cmap = pl.get_cmap('jet') + depths_max = max([d.max() for d in depths]) + depths = [cmap(d/depths_max) for d in depths] + confs_max = max([d.max() for d in confs]) + confs = [cmap(d/confs_max) for d in confs] + init_confs_max = max([d.max() for d in init_confs]) + init_confs = [cmap(d/init_confs_max) for d in init_confs] + + imgs = [] + for i in range(len(rgbimg)): + imgs.append(rgbimg[i]) + imgs.append(rgb(depths[i])) + imgs.append(rgb(confs[i])) + imgs.append(rgb(init_confs[i])) + + # if two images, and the shape is same, we can compute the dynamic mask + if len(rgbimg) == 2 and rgbimg[0].shape == rgbimg[1].shape: + motion_mask_thre = 0.35 + error_map = get_dynamic_mask_from_pairviewer(scene, both_directions=True, output_dir=args.output_dir, motion_mask_thre=motion_mask_thre) + # imgs.append(rgb(error_map)) + # apply threshold on the error map + normalized_error_map = (error_map - error_map.min()) / (error_map.max() - error_map.min()) + error_map_max = normalized_error_map.max() + error_map = cmap(normalized_error_map/error_map_max) + imgs.append(rgb(error_map)) + binary_error_map = (normalized_error_map > motion_mask_thre).astype(np.uint8) + imgs.append(rgb(binary_error_map*255)) + + return scene, outfile, imgs + + +def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type): + # if inputfiles[0] is a video, set the num_files to 200 + if inputfiles is not None and len(inputfiles) == 1 and inputfiles[0].name.endswith(('.mp4', '.avi', '.mov', '.MP4', '.AVI', '.MOV')): + num_files = 200 + else: + num_files = len(inputfiles) if inputfiles is not None else 1 + max_winsize = max(1, math.ceil((num_files-1)/2)) + if scenegraph_type == "swin" or scenegraph_type == "swin2stride" or scenegraph_type == "swinstride": + winsize = gradio.Slider(label="Scene Graph: Window Size", value=min(max_winsize,5), + minimum=1, maximum=max_winsize, step=1, visible=True) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files-1, step=1, visible=False) + elif scenegraph_type == "oneref": + winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files-1, step=1, visible=True) + else: + winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files-1, step=1, visible=False) + return winsize, refid + + +def get_reconstructed_scene_realtime(args, model, device, silent, image_size, filelist, scenegraph_type, refid, seq_name, fps, num_frames): + """ + from a list of images, run dust3r inference, global aligner. + then run get_3D_model_from_scene + """ + model.eval() + imgs = load_images(filelist, size=image_size, verbose=not silent, fps=fps, num_frames=num_frames) + if len(imgs) == 1: + imgs = [imgs[0], copy.deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + + if scenegraph_type == "oneref": + scenegraph_type = scenegraph_type + "-" + str(refid) + elif scenegraph_type == "oneref_mid": + scenegraph_type = "oneref-" + str(len(imgs) // 2) + else: + raise ValueError(f"Unknown scenegraph type for realtime mode: {scenegraph_type}") + + pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=False) + output = inference(pairs, model, device, batch_size=args.batch_size, verbose=not silent) + + save_folder = f'{args.output_dir}/{seq_name}' #default is 'demo_tmp/NULL' + os.makedirs(save_folder, exist_ok=True) + + + view1, view2, pred1, pred2 = output['view1'], output['view2'], output['pred1'], output['pred2'] + pts1 = pred1['pts3d'].detach().cpu().numpy() + pts2 = pred2['pts3d_in_other_view'].detach().cpu().numpy() + for batch_idx in range(len(view1['img'])): + colors1 = rgb(view1['img'][batch_idx]) + colors2 = rgb(view2['img'][batch_idx]) + xyzrgb1 = np.concatenate([pts1[batch_idx], colors1], axis=-1) #(H, W, 6) + xyzrgb2 = np.concatenate([pts2[batch_idx], colors2], axis=-1) + np.save(save_folder + '/pts3d1_p' + str(batch_idx) + '.npy', xyzrgb1) + np.save(save_folder + '/pts3d2_p' + str(batch_idx) + '.npy', xyzrgb2) + + conf1 = pred1['conf'][batch_idx].detach().cpu().numpy() + conf2 = pred2['conf'][batch_idx].detach().cpu().numpy() + np.save(save_folder + '/conf1_p' + str(batch_idx) + '.npy', conf1) + np.save(save_folder + '/conf2_p' + str(batch_idx) + '.npy', conf2) + + # save the imgs of two views + img1_rgb = cv2.cvtColor(colors1 * 255, cv2.COLOR_BGR2RGB) + img2_rgb = cv2.cvtColor(colors2 * 255, cv2.COLOR_BGR2RGB) + cv2.imwrite(save_folder + '/img1_p' + str(batch_idx) + '.png', img1_rgb) + cv2.imwrite(save_folder + '/img2_p' + str(batch_idx) + '.png', img2_rgb) + + return save_folder + + +def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False, args=None): + recon_fun = functools.partial(get_reconstructed_scene, args, tmpdirname, model, device, silent, image_size) + model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent) + with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="MonST3R Demo") as demo: + # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference + scene = gradio.State(None) + gradio.HTML(f'

MonST3R Demo

') + with gradio.Column(): + inputfiles = gradio.File(file_count="multiple") + with gradio.Row(): + schedule = gradio.Dropdown(["linear", "cosine"], + value='linear', label="schedule", info="For global alignment!") + niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000, + label="num_iterations", info="For global alignment!") + seq_name = gradio.Textbox(label="Sequence Name", placeholder="NULL", value=args.seq_name, info="For evaluation") + scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref", "swinstride", "swin2stride"], + value='swinstride', label="Scenegraph", + info="Define how to make pairs", + interactive=True) + winsize = gradio.Slider(label="Scene Graph: Window Size", value=5, + minimum=1, maximum=1, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False) + + run_btn = gradio.Button("Run") + + with gradio.Row(): + # adjust the confidence thresholdx + min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.1, minimum=0.0, maximum=20, step=0.01) + # adjust the camera size in the output pointcloud + cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001) + # adjust the temporal smoothing weight + temporal_smoothing_weight = gradio.Slider(label="temporal_smoothing_weight", value=0.01, minimum=0.0, maximum=0.1, step=0.001) + # add translation weight + translation_weight = gradio.Textbox(label="translation_weight", placeholder="1.0", value="1.0", info="For evaluation") + # change to another model + new_model_weights = gradio.Textbox(label="New Model", placeholder=args.weights, value=args.weights, info="Path to updated model weights") + with gradio.Row(): + as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud") + # two post process implemented + mask_sky = gradio.Checkbox(value=False, label="Mask sky") + clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps") + transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras") + # not to show camera + show_cam = gradio.Checkbox(value=True, label="Show Camera") + shared_focal = gradio.Checkbox(value=True, label="Shared Focal Length") + use_davis_gt_mask = gradio.Checkbox(value=False, label="Use GT Mask (DAVIS)") + with gradio.Row(): + flow_loss_weight = gradio.Slider(label="Flow Loss Weight", value=0.01, minimum=0.0, maximum=1.0, step=0.001) + flow_loss_start_iter = gradio.Slider(label="Flow Loss Start Iter", value=0.1, minimum=0, maximum=1, step=0.01) + flow_loss_threshold = gradio.Slider(label="Flow Loss Threshold", value=25, minimum=0, maximum=100, step=1) + # for video processing + fps = gradio.Slider(label="FPS", value=0, minimum=0, maximum=60, step=1) + num_frames = gradio.Slider(label="Num Frames", value=100, minimum=0, maximum=200, step=1) + + outmodel = gradio.Model3D() + outgallery = gradio.Gallery(label='rgb,depth,confidence, init_conf', columns=4, height="100%") + + # events + scenegraph_type.change(set_scenegraph_options, + inputs=[inputfiles, winsize, refid, scenegraph_type], + outputs=[winsize, refid]) + inputfiles.change(set_scenegraph_options, + inputs=[inputfiles, winsize, refid, scenegraph_type], + outputs=[winsize, refid]) + run_btn.click(fn=recon_fun, + inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud, + mask_sky, clean_depth, transparent_cams, cam_size, show_cam, + scenegraph_type, winsize, refid, seq_name, new_model_weights, + temporal_smoothing_weight, translation_weight, shared_focal, + flow_loss_weight, flow_loss_start_iter, flow_loss_threshold, use_davis_gt_mask, + fps, num_frames], + outputs=[scene, outmodel, outgallery]) + min_conf_thr.release(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + cam_size.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + as_pointcloud.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + mask_sky.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + clean_depth.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + transparent_cams.change(model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + demo.launch(share=args.share, server_name=server_name, server_port=server_port) + + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + + if args.output_dir is not None: + tmp_path = args.output_dir + os.makedirs(tmp_path, exist_ok=True) + tempfile.tempdir = tmp_path + + if args.server_name is not None: + server_name = args.server_name + else: + server_name = '0.0.0.0' if args.local_network else '127.0.0.1' + + if args.weights is not None and os.path.exists(args.weights): + weights_path = args.weights + else: + weights_path = args.model_name + + model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device) + + # Use the provided output_dir or create a temporary directory + tmpdirname = args.output_dir if args.output_dir is not None else tempfile.mkdtemp(suffix='monst3r_gradio_demo') + + if not args.silent: + print('Outputting stuff in', tmpdirname) + + if args.input_dir is not None: + # Process images in the input directory with default parameters + if os.path.isdir(args.input_dir): # input_dir is a directory of images + input_files = [os.path.join(args.input_dir, fname) for fname in sorted(os.listdir(args.input_dir))] + else: # input_dir is a video + input_files = [args.input_dir] + + if args.real_time: + recon_fun = functools.partial(get_reconstructed_scene_realtime, args, model, args.device, args.silent, args.image_size) + outfile = recon_fun( + filelist=input_files, + scenegraph_type='oneref_mid', + refid=0, + seq_name=args.seq_name, + fps=args.fps, + num_frames=args.num_frames, + ) + else: + recon_fun = functools.partial(get_reconstructed_scene, args, tmpdirname, model, args.device, args.silent, args.image_size) + # Call the function with default parameters + scene, outfile, imgs = recon_fun( + filelist=input_files, + schedule='linear', + niter=300, + min_conf_thr=1.1, + as_pointcloud=True, + mask_sky=False, + clean_depth=True, + transparent_cams=False, + cam_size=0.05, + show_cam=True, + scenegraph_type='swinstride', + winsize=5, + refid=0, + seq_name=args.seq_name, + new_model_weights=args.weights, + temporal_smoothing_weight=0.01, + translation_weight='1.0', + shared_focal=True, + flow_loss_weight=0.01, + flow_loss_start_iter=0.1, + flow_loss_threshold=25, + use_gt_mask=args.use_gt_davis_masks, + fps=args.fps, + num_frames=args.num_frames, + ) + print(f"Processing completed. Output saved in {tmpdirname}/{args.seq_name}") + else: + # Launch Gradio demo + main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent, args=args) diff --git a/monst3r/demo_render.py b/monst3r/demo_render.py new file mode 100644 index 0000000..9bedf4c --- /dev/null +++ b/monst3r/demo_render.py @@ -0,0 +1,608 @@ +# -------------------------------------------------------- +# gradio demo +# -------------------------------------------------------- + +import argparse +from fileinput import fileno +import math +from struct import calcsize +from altair import Theta +import gradio +import os +from sympy import ask +import torch +import numpy as np +import tempfile +import functools +import copy + +# from trimesh import PointCloud + +from dust3r.inference import inference +from dust3r.model import AsymmetricCroCo3DStereo +from dust3r.image_pairs import make_pairs +from dust3r.utils.image import load_images, rgb, enlarge_seg_masks +from dust3r.utils.device import to_numpy +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode +from dust3r.utils.viz_demo import convert_scene_output_to_glb, get_dynamic_mask_from_pairviewer +import matplotlib.pyplot as pl +import cv2 + +from pvd_utils import world_point_to_obj, sphere2pose, generate_traj_specified, setup_renderer, save_video, save_frames + +pl.ion() +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + +## render imports +from pytorch3d.renderer import PerspectiveCameras +from pytorch3d.structures import Pointclouds +from PIL import Image +import torchvision +import torch.nn.functional as F + + +def render_pcd(pts3d,imgs,masks,views,renderer,device='cpu',nbv=False): + import pdb; pdb.set_trace() + imgs = to_numpy(imgs) + pts3d = to_numpy(pts3d) + + if masks == None: + pts = torch.from_numpy(np.concatenate([p for p in pts3d])).view(-1, 3).to(device) + col = torch.from_numpy(np.concatenate([p for p in imgs])).view(-1, 3).to(device) + else: + # masks = to_numpy(masks) + pts = torch.from_numpy(np.concatenate([p[m] for p, m in zip(pts3d, masks)])).to(device) + col = torch.from_numpy(np.concatenate([p[m] for p, m in zip(imgs, masks)])).to(device) + + point_cloud = Pointclouds(points=[pts], features=[col]).extend(views) + images = renderer(point_cloud) + + if nbv: + color_mask = torch.ones(col.shape).to(device) + point_cloud_mask = Pointclouds(points=[pts],features=[color_mask]).extend(views) + view_masks = renderer(point_cloud_mask) + else: + view_masks = None + + return images, view_masks + +def run_render(pcd, imgs,masks, H, W, camera_traj,num_views,device='cpu', nbv=False): + render_setup = setup_renderer(camera_traj, image_size=(H,W)) + renderer = render_setup['renderer'] + render_results, viewmask = render_pcd(pcd, imgs, masks, num_views,renderer,device=device,nbv=nbv) + return render_results, viewmask + +def generate_traj_from_single_c2w(c2ws_anchor,H,W,fs,c,theta=0., phi=0.,d_r=0.,d_x=0.,d_y=0.,frame=1,device='cuda'): + # Initialize a camera. + """ + The camera coordinate sysmte in COLMAP is right-down-forward + Pytorch3D is left-up-forward + """ + c2w_new = sphere2pose(c2ws_anchor, np.float32(theta), np.float32(phi), np.float32(d_r), device, np.float32(d_x),np.float32(d_y)) + c2ws = c2w_new + num_views = c2ws.shape[0] + + R, T = c2ws[:,:3, :3], c2ws[:,:3, 3:] + ## 将dust3r坐标系转成pytorch3d坐标系 + R = torch.stack([-R[:,:, 0], -R[:,:, 1], R[:,:, 2]], 2) # from RDF to LUF for Rotation + new_c2w = torch.cat([R, T], 2) + w2c = torch.linalg.inv(torch.cat((new_c2w, torch.Tensor([[[0,0,0,1]]]).to(device).repeat(new_c2w.shape[0],1,1)),1)) + R_new, T_new = w2c[:,:3, :3].permute(0,2,1), w2c[:,:3, 3] # convert R to row-major matrix + image_size = ((H, W),) # (h, w) + cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False, image_size=image_size, R=R_new, T=T_new, device=device) + return cameras,num_views + + +def generate_traj_from_c2ws(c2ws_anchor,H,W,fs,c,frame,device): + # Initialize a camera. + """ + The camera coordinate sysmte in COLMAP is right-down-forward + Pytorch3D is left-up-forward + """ + theta = 0 + phi = 0 + r = 1 + x = 0 + y = 0 + + c2ws_list = [] + for i in range(frame): + c2w_new = sphere2pose(c2ws_anchor[i:i+1], np.float32(theta), np.float32(phi), np.float32(r), device, np.float32(x),np.float32(y)) + c2ws_list.append(c2w_new) + c2ws = torch.cat(c2ws_list,dim=0) + num_views = c2ws.shape[0] + + R, T = c2ws[:,:3, :3], c2ws[:,:3, 3:] + ## 将dust3r坐标系转成pytorch3d坐标系 + R = torch.stack([-R[:,:, 0], -R[:,:, 1], R[:,:, 2]], 2) # from RDF to LUF for Rotation + new_c2w = torch.cat([R, T], 2) + w2c = torch.linalg.inv(torch.cat((new_c2w, torch.Tensor([[[0,0,0,1]]]).to(device).repeat(new_c2w.shape[0],1,1)),1)) + R_new, T_new = w2c[:,:3, :3].permute(0,2,1), w2c[:,:3, 3] # convert R to row-major matrix + image_size = ((H, W),) # (h, w) + cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False, image_size=image_size, R=R_new, T=T_new, device=device) + return cameras,num_views + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser_url = parser.add_mutually_exclusive_group() + parser_url.add_argument("--local_network", action='store_true', default=False, + help="make app accessible on local network: address will be set to 0.0.0.0") + parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1") + parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size") + parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). " + "If None, will search for an available port starting at 7860."), + default=None) + parser.add_argument("--weights", type=str, help="path to the model weights", default='checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth') + parser.add_argument("--model_name", type=str, default='Junyi42/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt', help="model name") + parser.add_argument("--device", type=str, default='cuda', help="pytorch device") + parser.add_argument("--output_dir", type=str, default='./demo_tmp', help="value for tempfile.tempdir") + parser.add_argument("--silent", action='store_true', default=False, + help="silence logs") + parser.add_argument("--input_dir", type=str, help="Path to input images directory", default=None) + parser.add_argument("--seq_name", type=str, help="Sequence name for evaluation", default='NULL') + parser.add_argument('--use_gt_davis_masks', action='store_true', default=False, help='Use ground truth masks for DAVIS') + parser.add_argument('--not_batchify', action='store_true', default=False, help='Use non batchify mode for global optimization') + parser.add_argument('--real_time', action='store_true', default=False, help='Realtime mode') + parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference') + + parser.add_argument('--fps', type=int, default=0, help='FPS for video processing') + parser.add_argument('--num_frames', type=int, default=200, help='Maximum number of frames for video processing') + + # Add "share" argument if you want to make the demo accessible on the public internet + parser.add_argument("--share", action='store_true', default=False, help="Share the demo") + return parser + +def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False, + clean_depth=False, transparent_cams=False, cam_size=0.05, show_cam=True, save_name=None, thr_for_init_conf=True): + """ + extract 3D_model (glb file) from a reconstructed scene + """ + if scene is None: + return None + # post processes + if clean_depth: + scene = scene.clean_pointcloud() + if mask_sky: + scene = scene.mask_sky() + + # get optimized values from scene + rgbimg = scene.imgs + focals = scene.get_focals().cpu() + cams2world = scene.get_im_poses().cpu() + # 3D pointcloud from depthmap, poses and intrinsics + pts3d = to_numpy(scene.get_pts3d(raw_pts=True)) + scene.min_conf_thr = min_conf_thr + scene.thr_for_init_conf = thr_for_init_conf + msk = to_numpy(scene.get_masks()) + cmap = pl.get_cmap('viridis') + cam_color = [cmap(i/len(rgbimg))[:3] for i in range(len(rgbimg))] + cam_color = [(255*c[0], 255*c[1], 255*c[2]) for c in cam_color] + + return convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud, + transparent_cams=transparent_cams, cam_size=cam_size, show_cam=show_cam, silent=silent, save_name=save_name, + cam_color=cam_color) + +def render_2D_images_from_scene(outdir, scene, min_conf_thr=3, mask_sky=False, + clean_depth=False, thr_for_init_conf=True): + """ + render 2D images from a reconstructed scene + """ + if scene is None: + return None + # post processes + if clean_depth: + scene = scene.clean_pointcloud() + if mask_sky: + scene = scene.mask_sky() + + # get optimized values from scene + rgbimg = scene.imgs + # scene.min_conf_thr = min_conf_thr + # scene.thr_for_init_conf = thr_for_init_conf + # msk = to_numpy(scene.get_masks()) + + # DEBUG, render images + elevation = 5.0 + center_scale = 1.0 + d_x = torch.tensor(0.0) + d_y = torch.tensor(0.0) + d_r = torch.tensor(0.0) + d_theta = torch.tensor(0.0) + d_phi = torch.tensor(0.0) + dpt_trd = torch.tensor(1.0) + device = torch.device('cuda') + H, W = rgbimg[0].shape[:2] + + c2ws = scene.get_im_poses().detach() + pcd = [i.detach() for i in scene.get_pts3d(clip_thred=dpt_trd, raw_pts=True)] # a list of points of size whc + focals = scene.get_focals().detach() + principal_points = scene.get_principal_points().detach() + depth = to_numpy(scene.get_depthmaps()) + depth_avg = depth[0][H//2,W//2] + radius = depth_avg*center_scale + + c2ws,pcd = world_point_to_obj(poses=c2ws, points=torch.stack(pcd), k=0, r=radius, elevation=elevation, device=device) + + # import pdb; pdb.set_trace() + # device = torch.device('cuda') + # camera_traj,num_views = generate_traj_specified(c2ws[0:1], H, W, focals[0:1], principal_points[0:1], d_theta, d_phi, d_r,d_x*depth_avg/focals[0].item(),d_y*depth_avg/focals[0].item(),1, device) + # camera_traj,num_views = generate_traj_from_c2ws(c2ws, H, W, focals, principal_points, len(c2ws), device) + + + # Apply the fixed view transformation to all camera poses + camera_poses = [] + for i in range(len(c2ws)): + # Generate single transformed pose + single_pose, _ = generate_traj_from_single_c2w( + c2ws[i:i+1], # Take one pose at a time + H, W, + focals[0:1], # Use first frame's focal length for all + principal_points[0:1], # Use first frame's principal point for all + frame=1, # Only generate one frame per pose + device=device + ) + camera_poses.append(single_pose) + + # Combine all transformed cameras + camera_traj = camera_poses[0] + for cam in camera_poses[1:]: + camera_traj.R = torch.cat([camera_traj.R, cam.R]) + camera_traj.T = torch.cat([camera_traj.T, cam.T]) + num_views = camera_traj.R.shape[0] + + render_results, viewmask = run_render(pcd[0:1], rgbimg[0:1], ask[0:1], H, W, camera_traj,num_views, device=device, nbv=False) + render_results = F.interpolate(render_results.permute(0,3,1,2), size=(576, 1024), mode='bilinear', align_corners=False).permute(0,2,3,1) + save_frames(render_results, outdir, save_names=[f'render_{i}' for i in range(len(render_results))], folder=outdir) + save_video(render_results, os.path.join(outdir, 'render.mp4')) + + return render_results + +def get_reconstructed_scene(args, outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr, + as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, show_cam, scenegraph_type, winsize, refid, + seq_name, new_model_weights, temporal_smoothing_weight, translation_weight, shared_focal, + flow_loss_weight, flow_loss_start_iter, flow_loss_threshold, use_gt_mask, fps, num_frames): + """ + from a list of images, run dust3r inference, global aligner. + then run get_3D_model_from_scene + """ + translation_weight = float(translation_weight) + if new_model_weights != args.weights: + model = AsymmetricCroCo3DStereo.from_pretrained(new_model_weights).to(device) + model.eval() + if seq_name != "NULL": + dynamic_mask_path = f'data/davis/DAVIS/masked_images/480p/{seq_name}' + else: + dynamic_mask_path = None + imgs = load_images(filelist, size=image_size, verbose=not silent, dynamic_mask_root=dynamic_mask_path, fps=fps, num_frames=num_frames) + if len(imgs) == 1: + imgs = [imgs[0], copy.deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + if scenegraph_type == "swin" or scenegraph_type == "swinstride" or scenegraph_type == "swin2stride": + scenegraph_type = scenegraph_type + "-" + str(winsize) + "-noncyclic" + elif scenegraph_type == "oneref": + scenegraph_type = scenegraph_type + "-" + str(refid) + + pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True) + output = inference(pairs, model, device, batch_size=args.batch_size, verbose=not silent) + if len(imgs) > 2: + mode = GlobalAlignerMode.PointCloudOptimizer + scene = global_aligner(output, device=device, mode=mode, verbose=not silent, shared_focal = shared_focal, temporal_smoothing_weight=temporal_smoothing_weight, translation_weight=translation_weight, + flow_loss_weight=flow_loss_weight, flow_loss_start_epoch=flow_loss_start_iter, flow_loss_thre=flow_loss_threshold, use_self_mask=not use_gt_mask, + num_total_iter=niter, empty_cache= len(filelist) > 72, batchify=not args.not_batchify) + else: + mode = GlobalAlignerMode.PairViewer + scene = global_aligner(output, device=device, mode=mode, verbose=not silent) + lr = 0.01 + + if mode == GlobalAlignerMode.PointCloudOptimizer: + loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr) + + save_folder = f'{args.output_dir}/{seq_name}' #default is 'demo_tmp/NULL' + os.makedirs(save_folder, exist_ok=True) + # outfile = get_3D_model_from_scene(save_folder, silent, scene, min_conf_thr, as_pointcloud, mask_sky, + # clean_depth, transparent_cams, cam_size, show_cam) + + outfile = None + + render_results = render_2D_images_from_scene(save_folder, scene, min_conf_thr, mask_sky, + clean_depth, thr_for_init_conf=True) + + poses = scene.save_tum_poses(f'{save_folder}/pred_traj.txt') + K = scene.save_intrinsics(f'{save_folder}/pred_intrinsics.txt') + depth_maps = scene.save_depth_maps(save_folder) + dynamic_masks = scene.save_dynamic_masks(save_folder) + conf = scene.save_conf_maps(save_folder) + init_conf = scene.save_init_conf_maps(save_folder) + rgbs = scene.save_rgb_imgs(save_folder) + enlarge_seg_masks(save_folder, kernel_size=5 if use_gt_mask else 3) + + # also return rgb, depth and confidence imgs + # depth is normalized with the max value for all images + # we apply the jet colormap on the confidence maps + rgbimg = scene.imgs + depths = to_numpy(scene.get_depthmaps()) + confs = to_numpy([c for c in scene.im_conf]) + init_confs = to_numpy([c for c in scene.init_conf_maps]) + cmap = pl.get_cmap('jet') + depths_max = max([d.max() for d in depths]) + depths = [cmap(d/depths_max) for d in depths] + confs_max = max([d.max() for d in confs]) + confs = [cmap(d/confs_max) for d in confs] + init_confs_max = max([d.max() for d in init_confs]) + init_confs = [cmap(d/init_confs_max) for d in init_confs] + + imgs = [] + for i in range(len(rgbimg)): + imgs.append(rgbimg[i]) + imgs.append(rgb(depths[i])) + imgs.append(rgb(confs[i])) + imgs.append(rgb(init_confs[i])) + + # if two images, and the shape is same, we can compute the dynamic mask + if len(rgbimg) == 2 and rgbimg[0].shape == rgbimg[1].shape: + motion_mask_thre = 0.35 + error_map = get_dynamic_mask_from_pairviewer(scene, both_directions=True, output_dir=args.output_dir, motion_mask_thre=motion_mask_thre) + # imgs.append(rgb(error_map)) + # apply threshold on the error map + normalized_error_map = (error_map - error_map.min()) / (error_map.max() - error_map.min()) + error_map_max = normalized_error_map.max() + error_map = cmap(normalized_error_map/error_map_max) + imgs.append(rgb(error_map)) + binary_error_map = (normalized_error_map > motion_mask_thre).astype(np.uint8) + imgs.append(rgb(binary_error_map*255)) + + return scene, outfile, imgs + + +def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type): + # if inputfiles[0] is a video, set the num_files to 200 + if inputfiles is not None and len(inputfiles) == 1 and inputfiles[0].name.endswith(('.mp4', '.avi', '.mov', '.MP4', '.AVI', '.MOV')): + num_files = 200 + else: + num_files = len(inputfiles) if inputfiles is not None else 1 + max_winsize = max(1, math.ceil((num_files-1)/2)) + if scenegraph_type == "swin" or scenegraph_type == "swin2stride" or scenegraph_type == "swinstride": + winsize = gradio.Slider(label="Scene Graph: Window Size", value=min(max_winsize,5), + minimum=1, maximum=max_winsize, step=1, visible=True) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files-1, step=1, visible=False) + elif scenegraph_type == "oneref": + winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files-1, step=1, visible=True) + else: + winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files-1, step=1, visible=False) + return winsize, refid + + +def get_reconstructed_scene_realtime(args, model, device, silent, image_size, filelist, scenegraph_type, refid, seq_name, fps, num_frames): + """ + from a list of images, run dust3r inference, global aligner. + then run get_3D_model_from_scene + """ + model.eval() + imgs = load_images(filelist, size=image_size, verbose=not silent, fps=fps, num_frames=num_frames) + if len(imgs) == 1: + imgs = [imgs[0], copy.deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + + if scenegraph_type == "oneref": + scenegraph_type = scenegraph_type + "-" + str(refid) + elif scenegraph_type == "oneref_mid": + scenegraph_type = "oneref-" + str(len(imgs) // 2) + else: + raise ValueError(f"Unknown scenegraph type for realtime mode: {scenegraph_type}") + + pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=False) + output = inference(pairs, model, device, batch_size=args.batch_size, verbose=not silent) + + save_folder = f'{args.output_dir}/{seq_name}' #default is 'demo_tmp/NULL' + os.makedirs(save_folder, exist_ok=True) + + + view1, view2, pred1, pred2 = output['view1'], output['view2'], output['pred1'], output['pred2'] + pts1 = pred1['pts3d'].detach().cpu().numpy() + pts2 = pred2['pts3d_in_other_view'].detach().cpu().numpy() + for batch_idx in range(len(view1['img'])): + colors1 = rgb(view1['img'][batch_idx]) + colors2 = rgb(view2['img'][batch_idx]) + xyzrgb1 = np.concatenate([pts1[batch_idx], colors1], axis=-1) #(H, W, 6) + xyzrgb2 = np.concatenate([pts2[batch_idx], colors2], axis=-1) + np.save(save_folder + '/pts3d1_p' + str(batch_idx) + '.npy', xyzrgb1) + np.save(save_folder + '/pts3d2_p' + str(batch_idx) + '.npy', xyzrgb2) + + conf1 = pred1['conf'][batch_idx].detach().cpu().numpy() + conf2 = pred2['conf'][batch_idx].detach().cpu().numpy() + np.save(save_folder + '/conf1_p' + str(batch_idx) + '.npy', conf1) + np.save(save_folder + '/conf2_p' + str(batch_idx) + '.npy', conf2) + + # save the imgs of two views + img1_rgb = cv2.cvtColor(colors1 * 255, cv2.COLOR_BGR2RGB) + img2_rgb = cv2.cvtColor(colors2 * 255, cv2.COLOR_BGR2RGB) + cv2.imwrite(save_folder + '/img1_p' + str(batch_idx) + '.png', img1_rgb) + cv2.imwrite(save_folder + '/img2_p' + str(batch_idx) + '.png', img2_rgb) + + return save_folder + + +def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False, args=None): + recon_fun = functools.partial(get_reconstructed_scene, args, tmpdirname, model, device, silent, image_size) + model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent) + with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="MonST3R Demo") as demo: + # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference + scene = gradio.State(None) + gradio.HTML(f'

MonST3R Demo

') + with gradio.Column(): + inputfiles = gradio.File(file_count="multiple") + with gradio.Row(): + schedule = gradio.Dropdown(["linear", "cosine"], + value='linear', label="schedule", info="For global alignment!") + niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000, + label="num_iterations", info="For global alignment!") + seq_name = gradio.Textbox(label="Sequence Name", placeholder="NULL", value=args.seq_name, info="For evaluation") + scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref", "swinstride", "swin2stride"], + value='swinstride', label="Scenegraph", + info="Define how to make pairs", + interactive=True) + winsize = gradio.Slider(label="Scene Graph: Window Size", value=5, + minimum=1, maximum=1, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False) + + run_btn = gradio.Button("Run") + + with gradio.Row(): + # adjust the confidence thresholdx + min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.1, minimum=0.0, maximum=20, step=0.01) + # adjust the camera size in the output pointcloud + cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001) + # adjust the temporal smoothing weight + temporal_smoothing_weight = gradio.Slider(label="temporal_smoothing_weight", value=0.01, minimum=0.0, maximum=0.1, step=0.001) + # add translation weight + translation_weight = gradio.Textbox(label="translation_weight", placeholder="1.0", value="1.0", info="For evaluation") + # change to another model + new_model_weights = gradio.Textbox(label="New Model", placeholder=args.weights, value=args.weights, info="Path to updated model weights") + with gradio.Row(): + as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud") + # two post process implemented + mask_sky = gradio.Checkbox(value=False, label="Mask sky") + clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps") + transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras") + # not to show camera + show_cam = gradio.Checkbox(value=True, label="Show Camera") + shared_focal = gradio.Checkbox(value=True, label="Shared Focal Length") + use_davis_gt_mask = gradio.Checkbox(value=False, label="Use GT Mask (DAVIS)") + with gradio.Row(): + flow_loss_weight = gradio.Slider(label="Flow Loss Weight", value=0.01, minimum=0.0, maximum=1.0, step=0.001) + flow_loss_start_iter = gradio.Slider(label="Flow Loss Start Iter", value=0.1, minimum=0, maximum=1, step=0.01) + flow_loss_threshold = gradio.Slider(label="Flow Loss Threshold", value=25, minimum=0, maximum=100, step=1) + # for video processing + fps = gradio.Slider(label="FPS", value=0, minimum=0, maximum=60, step=1) + num_frames = gradio.Slider(label="Num Frames", value=100, minimum=0, maximum=200, step=1) + + outmodel = gradio.Model3D() + outgallery = gradio.Gallery(label='rgb,depth,confidence, init_conf', columns=4, height="100%") + + # events + scenegraph_type.change(set_scenegraph_options, + inputs=[inputfiles, winsize, refid, scenegraph_type], + outputs=[winsize, refid]) + inputfiles.change(set_scenegraph_options, + inputs=[inputfiles, winsize, refid, scenegraph_type], + outputs=[winsize, refid]) + run_btn.click(fn=recon_fun, + inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud, + mask_sky, clean_depth, transparent_cams, cam_size, show_cam, + scenegraph_type, winsize, refid, seq_name, new_model_weights, + temporal_smoothing_weight, translation_weight, shared_focal, + flow_loss_weight, flow_loss_start_iter, flow_loss_threshold, use_davis_gt_mask, + fps, num_frames], + outputs=[scene, outmodel, outgallery]) + min_conf_thr.release(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + cam_size.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + as_pointcloud.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + mask_sky.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + clean_depth.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + transparent_cams.change(model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + demo.launch(share=args.share, server_name=server_name, server_port=server_port) + + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + + if args.output_dir is not None: + tmp_path = args.output_dir + os.makedirs(tmp_path, exist_ok=True) + tempfile.tempdir = tmp_path + + if args.server_name is not None: + server_name = args.server_name + else: + server_name = '0.0.0.0' if args.local_network else '127.0.0.1' + + if args.weights is not None and os.path.exists(args.weights): + weights_path = args.weights + else: + weights_path = args.model_name + + model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device) + + # Use the provided output_dir or create a temporary directory + tmpdirname = args.output_dir if args.output_dir is not None else tempfile.mkdtemp(suffix='monst3r_gradio_demo') + + if not args.silent: + print('Outputting stuff in', tmpdirname) + + if args.input_dir is not None: + # Process images in the input directory with default parameters + if os.path.isdir(args.input_dir): # input_dir is a directory of images + input_files = [os.path.join(args.input_dir, fname) for fname in sorted(os.listdir(args.input_dir))] + else: # input_dir is a video + input_files = [args.input_dir] + + if args.real_time: + recon_fun = functools.partial(get_reconstructed_scene_realtime, args, model, args.device, args.silent, args.image_size) + outfile = recon_fun( + filelist=input_files, + scenegraph_type='oneref_mid', + refid=0, + seq_name=args.seq_name, + fps=args.fps, + num_frames=args.num_frames, + ) + else: + recon_fun = functools.partial(get_reconstructed_scene, args, tmpdirname, model, args.device, args.silent, args.image_size) + # Call the function with default parameters + scene, outfile, imgs = recon_fun( + filelist=input_files, + schedule='linear', + niter=300, + min_conf_thr=1.1, + as_pointcloud=True, + mask_sky=False, + clean_depth=True, + transparent_cams=False, + cam_size=0.05, + show_cam=True, + scenegraph_type='swinstride', + winsize=5, + refid=0, + seq_name=args.seq_name, + new_model_weights=args.weights, + temporal_smoothing_weight=0.01, + translation_weight='1.0', + shared_focal=True, + flow_loss_weight=0.01, + flow_loss_start_iter=0.1, + flow_loss_threshold=25, + use_gt_mask=args.use_gt_davis_masks, + fps=args.fps, + num_frames=args.num_frames, + ) + print(f"Processing completed. Output saved in {tmpdirname}/{args.seq_name}") + else: + # Launch Gradio demo + main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent, args=args) diff --git a/monst3r/download_checkpoints.sh b/monst3r/download_checkpoints.sh new file mode 100755 index 0000000..206c170 --- /dev/null +++ b/monst3r/download_checkpoints.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Download all required checkpoints for MonST3R data preparation +# These are third-party model weights that are publicly available + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +echo "=== Downloading MonST3R checkpoint ===" +# MonST3R ViT-Large model (2.2GB) +# Source: https://huggingface.co/Junyi42/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt +mkdir -p checkpoints +if [ ! -f checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth ]; then + gdown --fuzzy https://drive.google.com/file/d/1Z1jO_JmfZj0z3bgMvCwqfUhyZ1bIbc9E/view?usp=sharing -O checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth + echo "Done." +else + echo "Already exists, skipping." +fi + +echo "=== Downloading SAM2 checkpoint ===" +# SAM2.1 Hiera-Large (857MB) +# Source: https://github.com/facebookresearch/sam2 +mkdir -p third_party/sam2/checkpoints +if [ ! -f third_party/sam2/checkpoints/sam2.1_hiera_large.pt ]; then + wget -q --show-progress https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt -O third_party/sam2/checkpoints/sam2.1_hiera_large.pt + echo "Done." +else + echo "Already exists, skipping." +fi + +echo "=== Downloading RAFT (Sea-RAFT) checkpoint ===" +# Tartan-C-T-TSKH-spring540x960-M (76MB) +mkdir -p third_party/RAFT/models +if [ ! -f third_party/RAFT/models/Tartan-C-T-TSKH-spring540x960-M.pth ]; then + gdown --fuzzy https://drive.google.com/file/d/1a0C5FTdhjM4rKrfXiGhec7eq2YM141lu/view?usp=drive_link -O third_party/RAFT/models/Tartan-C-T-TSKH-spring540x960-M.pth + echo "Done." +else + echo "Already exists, skipping." +fi + +echo "=== Note ===" +echo "ResNet34 weights (resnet34-b627a593.pth) will be automatically" +echo "downloaded by torchvision at runtime. No manual download needed." +echo "" +echo "All checkpoints downloaded successfully!" diff --git a/monst3r/dust3r/__init__.py b/monst3r/dust3r/__init__.py new file mode 100644 index 0000000..a326921 --- /dev/null +++ b/monst3r/dust3r/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/monst3r/dust3r/cloud_opt/__init__.py b/monst3r/dust3r/cloud_opt/__init__.py new file mode 100644 index 0000000..faf5cd2 --- /dev/null +++ b/monst3r/dust3r/cloud_opt/__init__.py @@ -0,0 +1,33 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# global alignment optimization wrapper function +# -------------------------------------------------------- +from enum import Enum + +from .optimizer import PointCloudOptimizer +from .modular_optimizer import ModularPointCloudOptimizer +from .pair_viewer import PairViewer + + +class GlobalAlignerMode(Enum): + PointCloudOptimizer = "PointCloudOptimizer" + ModularPointCloudOptimizer = "ModularPointCloudOptimizer" + PairViewer = "PairViewer" + + +def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw): + # extract all inputs + view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()] + # build the optimizer + if mode == GlobalAlignerMode.PointCloudOptimizer: + net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device) + elif mode == GlobalAlignerMode.ModularPointCloudOptimizer: + net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device) + elif mode == GlobalAlignerMode.PairViewer: + net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device) + else: + raise NotImplementedError(f'Unknown mode {mode}') + + return net diff --git a/monst3r/dust3r/cloud_opt/base_opt.py b/monst3r/dust3r/cloud_opt/base_opt.py new file mode 100644 index 0000000..062eba5 --- /dev/null +++ b/monst3r/dust3r/cloud_opt/base_opt.py @@ -0,0 +1,563 @@ +# -------------------------------------------------------- +# Base class for the global alignement procedure +# -------------------------------------------------------- +from copy import deepcopy +import cv2 + +import numpy as np +import torch +import torch.nn as nn +import roma +from copy import deepcopy +import tqdm + +from dust3r.utils.geometry import inv, geotrf +from dust3r.utils.device import to_numpy +from dust3r.utils.image import rgb +from dust3r.viz import SceneViz, segment_sky, auto_cam_size +from dust3r.optim_factory import adjust_learning_rate_by_lr + +from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p, + cosine_schedule, linear_schedule, cycled_linear_schedule, get_conf_trf) +import dust3r.cloud_opt.init_im_poses as init_fun +from scipy.spatial.transform import Rotation +from dust3r.utils.vo_eval import save_trajectory_tum_format +import os +import matplotlib.pyplot as plt +from PIL import Image + +def c2w_to_tumpose(c2w): + """ + Convert a camera-to-world matrix to a tuple of translation and rotation + + input: c2w: 4x4 matrix + output: tuple of translation and rotation (x y z qw qx qy qz) + """ + # convert input to numpy + c2w = to_numpy(c2w) + xyz = c2w[:3, -1] + rot = Rotation.from_matrix(c2w[:3, :3]) + qx, qy, qz, qw = rot.as_quat() + tum_pose = np.concatenate([xyz, [qw, qx, qy, qz]]) + return tum_pose + +class BasePCOptimizer (nn.Module): + """ Optimize a global scene, given a list of pairwise observations. + Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, **kwargs): + if len(args) == 1 and len(kwargs) == 0: + other = deepcopy(args[0]) + attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes + min_conf_thr conf_thr conf_i conf_j im_conf + base_scale norm_pw_scale POSE_DIM pw_poses + pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose'''.split() + self.__dict__.update({k: other[k] for k in attrs}) + else: + self._init_from_views(*args, **kwargs) + + def _init_from_views(self, view1, view2, pred1, pred2, + dist='l1', + conf='log', + min_conf_thr=3, + thr_for_init_conf=False, + base_scale=0.5, + allow_pw_adaptors=False, + pw_break=20, + rand_pose=torch.randn, + empty_cache=False, + verbose=True): + super().__init__() + if not isinstance(view1['idx'], list): + view1['idx'] = view1['idx'].tolist() + if not isinstance(view2['idx'], list): + view2['idx'] = view2['idx'].tolist() + self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])] + self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges} + self.dist = ALL_DISTS[dist] + self.verbose = verbose + self.empty_cache = empty_cache + self.n_imgs = self._check_edges() + + # input data + pred1_pts = pred1['pts3d'] + pred2_pts = pred2['pts3d_in_other_view'] + self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)}) + self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)}) + self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts) + + # work in log-scale with conf + pred1_conf = pred1['conf'] # (Number of image_pairs, H, W) + pred2_conf = pred2['conf'] # (Number of image_pairs, H, W) + self.min_conf_thr = min_conf_thr + self.thr_for_init_conf = thr_for_init_conf + self.conf_trf = get_conf_trf(conf) + + self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)}) + self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)}) + self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf) + for i in range(len(self.im_conf)): + self.im_conf[i].requires_grad = False + + self.init_conf_maps = [c.clone() for c in self.im_conf] + + # pairwise pose parameters + self.base_scale = base_scale + self.norm_pw_scale = True + self.pw_break = pw_break + self.POSE_DIM = 7 + self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM))) # pairwise poses + self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2))) # slight xy/z adaptation + self.pw_adaptors.requires_grad_(allow_pw_adaptors) + self.has_im_poses = False + self.rand_pose = rand_pose + + # possibly store images, camera_pose, instance for show_pointcloud + self.imgs = None + if 'img' in view1 and 'img' in view2: + imgs = [torch.zeros((3,)+hw) for hw in self.imshapes] + for v in range(len(self.edges)): + idx = view1['idx'][v] + imgs[idx] = view1['img'][v] + idx = view2['idx'][v] + imgs[idx] = view2['img'][v] + self.imgs = rgb(imgs) + + self.dynamic_masks = None + if 'dynamic_mask' in view1 and 'dynamic_mask' in view2: + dynamic_masks = [torch.zeros(hw) for hw in self.imshapes] + for v in range(len(self.edges)): + idx = view1['idx'][v] + dynamic_masks[idx] = view1['dynamic_mask'][v] + idx = view2['idx'][v] + dynamic_masks[idx] = view2['dynamic_mask'][v] + self.dynamic_masks = dynamic_masks + + self.camera_poses = None + if 'camera_pose' in view1 and 'camera_pose' in view2: + camera_poses = [torch.zeros((4, 4)) for _ in range(self.n_imgs)] + for v in range(len(self.edges)): + idx = view1['idx'][v] + camera_poses[idx] = view1['camera_pose'][v] + idx = view2['idx'][v] + camera_poses[idx] = view2['camera_pose'][v] + self.camera_poses = camera_poses + + self.img_pathes = None + if 'instance' in view1 and 'instance' in view2: + img_pathes = ['' for _ in range(self.n_imgs)] + for v in range(len(self.edges)): + idx = view1['idx'][v] + img_pathes[idx] = view1['instance'][v] + idx = view2['idx'][v] + img_pathes[idx] = view2['instance'][v] + self.img_pathes = img_pathes + + @property + def n_edges(self): + return len(self.edges) + + @property + def str_edges(self): + return [edge_str(i, j) for i, j in self.edges] + + @property + def imsizes(self): + return [(w, h) for h, w in self.imshapes] + + @property + def device(self): + return next(iter(self.parameters())).device + + def state_dict(self, trainable=True): + all_params = super().state_dict() + return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable} + + def load_state_dict(self, data): + return super().load_state_dict(self.state_dict(trainable=False) | data) + + def _check_edges(self): + indices = sorted({i for edge in self.edges for i in edge}) + assert indices == list(range(len(indices))), 'bad pair indices: missing values ' + return len(indices) + + @torch.no_grad() + def _compute_img_conf(self, pred1_conf, pred2_conf): + im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes]) + for e, (i, j) in enumerate(self.edges): + im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e]) + im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e]) + return im_conf + + def get_adaptors(self): + adapt = self.pw_adaptors + adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1) # (scale_xy, scale_xy, scale_z) + if self.norm_pw_scale: # normalize so that the product == 1 + adapt = adapt - adapt.mean(dim=1, keepdim=True) + return (adapt / self.pw_break).exp() + + def _get_poses(self, poses): + # normalize rotation + Q = poses[:, :4] + T = signed_expm1(poses[:, 4:7]) + RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous() + return RT + + def _set_pose(self, poses, idx, R, T=None, scale=None, force=False): + # all poses == cam-to-world + pose = poses[idx] + if not (pose.requires_grad or force): + return pose + + if R.shape == (4, 4): + assert T is None + T = R[:3, 3] + R = R[:3, :3] + + if R is not None: + pose.data[0:4] = roma.rotmat_to_unitquat(R) + if T is not None: + pose.data[4:7] = signed_log1p(T / (scale or 1)) # translation is function of scale + + if scale is not None: + assert poses.shape[-1] in (8, 13) + pose.data[-1] = np.log(float(scale)) + return pose + + def get_pw_norm_scale_factor(self): + if self.norm_pw_scale: + # normalize scales so that things cannot go south + # we want that exp(scale) ~= self.base_scale + return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp() + else: + return 1 # don't norm scale for known poses + + def get_pw_scale(self): + scale = self.pw_poses[:, -1].exp() # (n_edges,) + scale = scale * self.get_pw_norm_scale_factor() + return scale + + def get_pw_poses(self): # cam to world + RT = self._get_poses(self.pw_poses) + scaled_RT = RT.clone() + scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1) # scale the rotation AND translation + return scaled_RT + + def get_masks(self): + if self.thr_for_init_conf: + return [(conf > self.min_conf_thr) for conf in self.init_conf_maps] + else: + return [(conf > self.min_conf_thr) for conf in self.im_conf] + + def depth_to_pts3d(self): + raise NotImplementedError() + + def get_pts3d(self, raw=False, **kwargs): + res = self.depth_to_pts3d(**kwargs) + if not raw: + res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def _set_focal(self, idx, focal, force=False): + raise NotImplementedError() + + def get_focals(self): + raise NotImplementedError() + + def get_known_focal_mask(self): + raise NotImplementedError() + + def get_principal_points(self): + raise NotImplementedError() + + def get_conf(self, mode=None): + trf = self.conf_trf if mode is None else get_conf_trf(mode) + return [trf(c) for c in self.im_conf] + + def get_init_conf(self, mode=None): + trf = self.conf_trf if mode is None else get_conf_trf(mode) + return [trf(c) for c in self.init_conf_maps] + + def get_im_poses(self): + raise NotImplementedError() + + def _set_depthmap(self, idx, depth, force=False): + raise NotImplementedError() + + def get_depthmaps(self, raw=False): + raise NotImplementedError() + + def clean_pointcloud(self, **kw): + cams = inv(self.get_im_poses()) + K = self.get_intrinsics() + depthmaps = self.get_depthmaps() + all_pts3d = self.get_pts3d() + + new_im_confs = clean_pointcloud(self.im_conf, K, cams, depthmaps, all_pts3d, **kw) + + for i, new_conf in enumerate(new_im_confs): + self.im_conf[i].data[:] = new_conf + return self + + def get_tum_poses(self): + poses = self.get_im_poses() + tt = np.arange(len(poses)).astype(float) + tum_poses = [c2w_to_tumpose(p) for p in poses] + tum_poses = np.stack(tum_poses, 0) + return [tum_poses, tt] + + def save_tum_poses(self, path): + traj = self.get_tum_poses() + save_trajectory_tum_format(traj, path) + return traj[0] # return the poses + + def save_focals(self, path): + # convert focal to txt + focals = self.get_focals() + np.savetxt(path, focals.detach().cpu().numpy(), fmt='%.6f') + return focals + + def save_intrinsics(self, path): + K_raw = self.get_intrinsics() + K = K_raw.reshape(-1, 9) + np.savetxt(path, K.detach().cpu().numpy(), fmt='%.6f') + return K_raw + + def save_conf_maps(self, path): + conf = self.get_conf() + for i, c in enumerate(conf): + np.save(f'{path}/conf_{i}.npy', c.detach().cpu().numpy()) + return conf + + def save_init_conf_maps(self, path): + conf = self.get_init_conf() + for i, c in enumerate(conf): + np.save(f'{path}/init_conf_{i}.npy', c.detach().cpu().numpy()) + return conf + + def save_rgb_imgs(self, path): + imgs = self.imgs + for i, img in enumerate(imgs): + # convert from rgb to bgr + img = img[..., ::-1] + cv2.imwrite(f'{path}/frame_{i:04d}.png', img*255) + return imgs + + def save_dynamic_masks(self, path): + dynamic_masks = self.dynamic_masks if getattr(self, 'sam2_dynamic_masks', None) is None else self.sam2_dynamic_masks + for i, dynamic_mask in enumerate(dynamic_masks): + cv2.imwrite(f'{path}/dynamic_mask_{i}.png', (dynamic_mask * 255).detach().cpu().numpy().astype(np.uint8)) + return dynamic_masks + + def save_depth_maps(self, path): + depth_maps = self.get_depthmaps() + images = [] + + for i, depth_map in enumerate(depth_maps): + # Apply color map to depth map + depth_map_colored = cv2.applyColorMap((depth_map * 255).detach().cpu().numpy().astype(np.uint8), cv2.COLORMAP_JET) + img_path = f'{path}/frame_{(i):04d}.png' + cv2.imwrite(img_path, depth_map_colored) + images.append(Image.open(img_path)) + np.save(f'{path}/frame_{(i):04d}.npy', depth_map.detach().cpu().numpy()) + + images[0].save(f'{path}/_depth_maps.gif', save_all=True, append_images=images[1:], duration=100, loop=0) + + return depth_maps + + def forward(self, ret_details=False): + pw_poses = self.get_pw_poses() # cam-to-world + pw_adapt = self.get_adaptors() + proj_pts3d = self.get_pts3d() + # pre-compute pixel weights + weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()} + weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()} + + loss = 0 + if ret_details: + details = -torch.ones((self.n_imgs, self.n_imgs)) + + for e, (i, j) in enumerate(self.edges): + i_j = edge_str(i, j) + # distance in image i and j + aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j]) + aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j]) + li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean() + lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean() + loss = loss + li + lj + + if ret_details: + details[i, j] = li + lj + loss /= self.n_edges # average over all pairs + + if ret_details: + return loss, details + return loss + + @torch.cuda.amp.autocast(enabled=False) + def compute_global_alignment(self, init=None, save_score_path=None, save_score_only=False, niter_PnP=10, **kw): + if init is None: + pass + elif init == 'msp' or init == 'mst': + init_fun.init_minimum_spanning_tree(self, save_score_path=save_score_path, save_score_only=save_score_only, niter_PnP=niter_PnP) + if save_score_only: # if only want the score map + return None + elif init == 'known_poses': + self.preset_pose(known_poses=self.camera_poses, requires_grad=True) + init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr, + niter_PnP=niter_PnP) + else: + raise ValueError(f'bad value for {init=}') + + return global_alignment_loop(self, **kw) + + @torch.no_grad() + def mask_sky(self): + res = deepcopy(self) + for i in range(self.n_imgs): + sky = segment_sky(self.imgs[i]) + res.im_conf[i][sky] = 0 + return res + + def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw): + viz = SceneViz() + if self.imgs is None: + colors = np.random.randint(0, 256, size=(self.n_imgs, 3)) + colors = list(map(tuple, colors.tolist())) + for n in range(self.n_imgs): + viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n]) + else: + viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks()) + colors = np.random.randint(256, size=(self.n_imgs, 3)) + + # camera poses + im_poses = to_numpy(self.get_im_poses()) + if cam_size is None: + cam_size = auto_cam_size(im_poses) + viz.add_cameras(im_poses, self.get_focals(), colors=colors, + images=self.imgs, imsizes=self.imsizes, cam_size=cam_size) + if show_pw_cams: + pw_poses = self.get_pw_poses() + viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size) + + if show_pw_pts3d: + pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)] + viz.add_pointcloud(pts, (128, 0, 128)) + + viz.show(**kw) + return viz + + +def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-3, temporal_smoothing_weight=0, depth_map_save_dir=None): + params = [p for p in net.parameters() if p.requires_grad] + if not params: + return net + + verbose = net.verbose + if verbose: + print('Global alignement - optimizing for:') + print([name for name, value in net.named_parameters() if value.requires_grad]) + + lr_base = lr + optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9)) + + loss = float('inf') + if verbose: + with tqdm.tqdm(total=niter) as bar: + while bar.n < bar.total: + if bar.n % 500 == 0 and depth_map_save_dir is not None: + if not os.path.exists(depth_map_save_dir): + os.makedirs(depth_map_save_dir) + # visualize the depthmaps + depth_maps = net.get_depthmaps() + for i, depth_map in enumerate(depth_maps): + depth_map_save_path = os.path.join(depth_map_save_dir, f'depthmaps_{i}_iter_{bar.n}.png') + plt.imsave(depth_map_save_path, depth_map.detach().cpu().numpy(), cmap='jet') + print(f"Saved depthmaps at iteration {bar.n} to {depth_map_save_dir}") + loss, lr = global_alignment_iter(net, bar.n, niter, lr_base, lr_min, optimizer, schedule, + temporal_smoothing_weight=temporal_smoothing_weight) + bar.set_postfix_str(f'{lr=:g} loss={loss:g}') + bar.update() + else: + for n in range(niter): + loss, _ = global_alignment_iter(net, n, niter, lr_base, lr_min, optimizer, schedule, + temporal_smoothing_weight=temporal_smoothing_weight) + return loss + + +def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule, temporal_smoothing_weight=0): + t = cur_iter / niter + if schedule == 'cosine': + lr = cosine_schedule(t, lr_base, lr_min) + elif schedule == 'linear': + lr = linear_schedule(t, lr_base, lr_min) + elif schedule.startswith('cycle'): + try: + num_cycles = int(schedule[5:]) + except ValueError: + num_cycles = 2 + lr = cycled_linear_schedule(t, lr_base, lr_min, num_cycles=num_cycles) + else: + raise ValueError(f'bad lr {schedule=}') + + adjust_learning_rate_by_lr(optimizer, lr) + optimizer.zero_grad() + + if net.empty_cache: + torch.cuda.empty_cache() + + loss = net(epoch=cur_iter) + + if net.empty_cache: + torch.cuda.empty_cache() + + loss.backward() + + if net.empty_cache: + torch.cuda.empty_cache() + + optimizer.step() + + return float(loss), lr + + + +@torch.no_grad() +def clean_pointcloud( im_confs, K, cams, depthmaps, all_pts3d, + tol=0.001, bad_conf=0, dbg=()): + """ Method: + 1) express all 3d points in each camera coordinate frame + 2) if they're in front of a depthmap --> then lower their confidence + """ + assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == len(all_pts3d) + assert 0 <= tol < 1 + res = [c.clone() for c in im_confs] + + # reshape appropriately + all_pts3d = [p.view(*c.shape,3) for p,c in zip(all_pts3d, im_confs)] + depthmaps = [d.view(*c.shape) for d,c in zip(depthmaps, im_confs)] + + for i, pts3d in enumerate(all_pts3d): + for j in range(len(all_pts3d)): + if i == j: continue + + # project 3dpts in other view + proj = geotrf(cams[j], pts3d) + proj_depth = proj[:,:,2] + u,v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1) + + # check which points are actually in the visible cone + H, W = im_confs[j].shape + msk_i = (proj_depth > 0) & (0 <= u) & (u < W) & (0 <= v) & (v < H) + msk_j = v[msk_i], u[msk_i] + + # find bad points = those in front but less confident + bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]) & (res[i][msk_i] < res[j][msk_j]) + + bad_msk_i = msk_i.clone() + bad_msk_i[msk_i] = bad_points + res[i][bad_msk_i] = res[i][bad_msk_i].clip_(max=bad_conf) + + return res diff --git a/monst3r/dust3r/cloud_opt/commons.py b/monst3r/dust3r/cloud_opt/commons.py new file mode 100644 index 0000000..8ed808d --- /dev/null +++ b/monst3r/dust3r/cloud_opt/commons.py @@ -0,0 +1,103 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utility functions for global alignment +# -------------------------------------------------------- +import torch +import torch.nn as nn +import numpy as np +from scipy.stats import zscore + +def edge_str(i, j): + return f'{i}_{j}' + + +def i_j_ij(ij): + # inputs are (i, j) + return edge_str(*ij), ij + + +def edge_conf(conf_i, conf_j, edge): + + score = float(conf_i[edge].mean() * conf_j[edge].mean()) + + return score + + +def compute_edge_scores(edges, conf_i, conf_j): + score_dict = {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges} + + return score_dict + +def NoGradParamDict(x): + assert isinstance(x, dict) + return nn.ParameterDict(x).requires_grad_(False) + + +def get_imshapes(edges, pred_i, pred_j): + n_imgs = max(max(e) for e in edges) + 1 + imshapes = [None] * n_imgs + for e, (i, j) in enumerate(edges): + shape_i = tuple(pred_i[e].shape[0:2]) + shape_j = tuple(pred_j[e].shape[0:2]) + if imshapes[i]: + assert imshapes[i] == shape_i, f'incorrect shape for image {i}' + if imshapes[j]: + assert imshapes[j] == shape_j, f'incorrect shape for image {j}' + imshapes[i] = shape_i + imshapes[j] = shape_j + return imshapes + + +def get_conf_trf(mode): + if mode == 'log': + def conf_trf(x): return x.log() + elif mode == 'sqrt': + def conf_trf(x): return x.sqrt() + elif mode == 'm1': + def conf_trf(x): return x-1 + elif mode in ('id', 'none'): + def conf_trf(x): return x + else: + raise ValueError(f'bad mode for {mode=}') + return conf_trf + + +def l2_dist(a, b, weight): + return ((a - b).square().sum(dim=-1) * weight) + + +def l1_dist(a, b, weight): + return ((a - b).norm(dim=-1) * weight) + + +ALL_DISTS = dict(l1=l1_dist, l2=l2_dist) + + +def signed_log1p(x): + sign = torch.sign(x) + return sign * torch.log1p(torch.abs(x)) + + +def signed_expm1(x): + sign = torch.sign(x) + return sign * torch.expm1(torch.abs(x)) + + +def cosine_schedule(t, lr_start, lr_end): + assert 0 <= t <= 1 + return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2 + + +def linear_schedule(t, lr_start, lr_end): + assert 0 <= t <= 1 + return lr_start + (lr_end - lr_start) * t + +def cycled_linear_schedule(t, lr_start, lr_end, num_cycles=2): + assert 0 <= t <= 1 + cycle_t = t * num_cycles + cycle_t = cycle_t - int(cycle_t) + if t == 1: + cycle_t = 1 + return linear_schedule(cycle_t, lr_start, lr_end) \ No newline at end of file diff --git a/monst3r/dust3r/cloud_opt/init_im_poses.py b/monst3r/dust3r/cloud_opt/init_im_poses.py new file mode 100644 index 0000000..6f879fa --- /dev/null +++ b/monst3r/dust3r/cloud_opt/init_im_poses.py @@ -0,0 +1,361 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Initialization functions for global alignment +# -------------------------------------------------------- +from functools import cache + +import numpy as np +import scipy.sparse as sp +import torch +import cv2 +import roma +from tqdm import tqdm + +from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses +from dust3r.post_process import estimate_focal_knowing_depth +from dust3r.viz import to_numpy + +from dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores +import matplotlib.pyplot as plt +import seaborn as sns + +def draw_edge_scores_map(edge_scores, save_path, n_imgs=None): + # Determine the size of the heatmap + if n_imgs is None: + n_imgs = max(max(edge) for edge in edge_scores) + 1 + + # Create a matrix to hold the scores + heatmap_matrix = np.full((n_imgs, n_imgs), np.nan) + + # Populate the matrix with the edge scores + for (i, j), score in edge_scores.items(): + heatmap_matrix[i, j] = score + + # Plotting the heatmap + plt.figure(figsize=(int(5.5*np.log(n_imgs)-2), int((5.5*np.log(n_imgs)-2) * 3 / 4))) + sns.heatmap(heatmap_matrix, annot=True, fmt=".1f", cmap="viridis", cbar=True, annot_kws={"fontsize": int(-4.2*np.log(n_imgs)+22.4)}) + plt.title("Heatmap of Edge Scores") + plt.xlabel("Node") + plt.ylabel("Node") + plt.savefig(save_path) + +@torch.no_grad() +def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3): + device = self.device + + # indices of known poses + nkp, known_poses_msk, known_poses = get_known_poses(self) + # assert nkp == self.n_imgs, 'not all poses are known' + + # get all focals + nkf, _, im_focals = get_known_focals(self) + # assert nkf == self.n_imgs + im_pp = self.get_principal_points() + + best_depthmaps = {} + # init all pairwise poses + for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)): + i_j = edge_str(i, j) + + # find relative pose for this pair + P1 = torch.eye(4, device=device) + msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1) + _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()), + pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP) + + # align the two predicted camera with the two gt cameras + s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]]) + # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1 + # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3]) + self._set_pose(self.pw_poses, e, R, T, scale=s) + + # remember if this is a good depthmap + score = float(self.conf_i[i_j].mean()) + if score > best_depthmaps.get(i, (0,))[0]: + best_depthmaps[i] = score, i_j, s + + # init all image poses + for n in range(self.n_imgs): + # assert known_poses_msk[n] + if n in best_depthmaps: + _, i_j, scale = best_depthmaps[n] + depth = self.pred_i[i_j][:, :, 2] + self._set_depthmap(n, depth * scale) + + +@torch.no_grad() +def init_minimum_spanning_tree(self, save_score_path=None, save_score_only=False, **kw): + """ Init all camera poses (image-wise and pairwise poses) given + an initial set of pairwise estimations. + """ + device = self.device + if save_score_only: + eadge_and_scores = compute_edge_scores(map(i_j_ij, self.edges), self.conf_i, self.conf_j) + draw_edge_scores_map(eadge_and_scores, save_score_path) + return + pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges, + self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr, + device, has_im_poses=self.has_im_poses, verbose=self.verbose, save_score_path=save_score_path, + **kw) + + return init_from_pts3d(self, pts3d, im_focals, im_poses) + + +def init_from_pts3d(self, pts3d, im_focals, im_poses): + # init poses + nkp, known_poses_msk, known_poses = get_known_poses(self) + if nkp == 1: + raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose") + elif nkp > 1: + # global rigid SE3 alignment + s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk]) + trf = sRT_to_4x4(s, R, T, device=known_poses.device) + + # rotate everything + im_poses = trf @ im_poses + im_poses[:, :3, :3] /= s # undo scaling on the rotation part + for img_pts3d in pts3d: + img_pts3d[:] = geotrf(trf, img_pts3d) + else: pass # no known poses + + # set all pairwise poses + for e, (i, j) in enumerate(self.edges): + i_j = edge_str(i, j) + # compute transform that goes from cam to world + s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]) + self._set_pose(self.pw_poses, e, R, T, scale=s) + + # take into account the scale normalization + s_factor = self.get_pw_norm_scale_factor() + im_poses[:, :3, 3] *= s_factor # apply downscaling factor + for img_pts3d in pts3d: + img_pts3d *= s_factor + + # init all image poses + if self.has_im_poses: + for i in range(self.n_imgs): + cam2world = im_poses[i] + depth = geotrf(inv(cam2world), pts3d[i])[..., 2] + self._set_depthmap(i, depth) + self._set_pose(self.im_poses, i, cam2world) + if im_focals[i] is not None: + if not self.shared_focal: + self._set_focal(i, im_focals[i]) + if self.shared_focal: + self._set_focal(0, sum(im_focals) / self.n_imgs) + if self.n_imgs > 2: + self._set_init_depthmap() + + if self.verbose: + with torch.no_grad(): + print(' init loss =', float(self())) + + +def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr, + device, has_im_poses=True, niter_PnP=10, verbose=True, save_score_path=None): + n_imgs = len(imshapes) + eadge_and_scores = compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j) + sparse_graph = -dict_to_sparse_graph(eadge_and_scores) + msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo() + + # temp variable to store 3d points + pts3d = [None] * len(imshapes) + + todo = sorted(zip(-msp.data, msp.row, msp.col)) # sorted edges + im_poses = [None] * n_imgs + im_focals = [None] * n_imgs + + # init with strongest edge + score, i, j = todo.pop() + if verbose: + print(f' init edge ({i}*,{j}*) {score=}') + if save_score_path is not None: + draw_edge_scores_map(eadge_and_scores, save_score_path, n_imgs=n_imgs) + save_tree_path = save_score_path.replace(".png", "_tree.txt") + with open(save_tree_path, "w") as f: + f.write(f'init edge ({i}*,{j}*) {score=}\n') + i_j = edge_str(i, j) + pts3d[i] = pred_i[i_j].clone() # the first one is set to be world coordinate + pts3d[j] = pred_j[i_j].clone() + done = {i, j} + if has_im_poses: + im_poses[i] = torch.eye(4, device=device) + im_focals[i] = estimate_focal(pred_i[i_j]) + + # set initial pointcloud based on pairwise graph + msp_edges = [(i, j)] + while todo: + # each time, predict the next one + score, i, j = todo.pop() + + if im_focals[i] is None: + im_focals[i] = estimate_focal(pred_i[i_j]) + + if i in done: # the first frame is already set, align the second frame with the first frame + if verbose: + print(f' init edge ({i},{j}*) {score=}') + if save_score_path is not None: + with open(save_tree_path, "a") as f: + f.write(f'init edge ({i},{j}*) {score=}\n') + assert j not in done + # align pred[i] with pts3d[i], and then set j accordingly + i_j = edge_str(i, j) + s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j]) + trf = sRT_to_4x4(s, R, T, device) + pts3d[j] = geotrf(trf, pred_j[i_j]) + done.add(j) + msp_edges.append((i, j)) + + if has_im_poses and im_poses[i] is None: + im_poses[i] = sRT_to_4x4(1, R, T, device) + + elif j in done: # the second frame is already set, align the first frame with the second frame + if verbose: + print(f' init edge ({i}*,{j}) {score=}') + if save_score_path is not None: + with open(save_tree_path, "a") as f: + f.write(f'init edge ({i}*,{j}) {score=}\n') + assert i not in done + i_j = edge_str(i, j) + s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j]) + trf = sRT_to_4x4(s, R, T, device) + pts3d[i] = geotrf(trf, pred_i[i_j]) + done.add(i) + msp_edges.append((i, j)) + + if has_im_poses and im_poses[i] is None: + im_poses[i] = sRT_to_4x4(1, R, T, device) + else: + # let's try again later + todo.insert(0, (score, i, j)) + + if has_im_poses: + # complete all missing informations + pair_scores = list(sparse_graph.values()) # already negative scores: less is best + edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)] + for i, j in edges_from_best_to_worse.tolist(): + if im_focals[i] is None: + im_focals[i] = estimate_focal(pred_i[edge_str(i, j)]) + + for i in range(n_imgs): + if im_poses[i] is None: + msk = im_conf[i] > min_conf_thr + res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP) + if res: + im_focals[i], im_poses[i] = res + if im_poses[i] is None: + im_poses[i] = torch.eye(4, device=device) + im_poses = torch.stack(im_poses) + else: + im_poses = im_focals = None + + return pts3d, msp_edges, im_focals, im_poses + + +def dict_to_sparse_graph(dic): + n_imgs = max(max(e) for e in dic) + 1 + res = sp.dok_array((n_imgs, n_imgs)) + for edge, value in dic.items(): + res[edge] = value + return res + + +def rigid_points_registration(pts1, pts2, conf): + R, T, s = roma.rigid_points_registration( + pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True) + return s, R, T # return un-scaled (R, T) + + +def sRT_to_4x4(scale, R, T, device): + trf = torch.eye(4, device=device) + trf[:3, :3] = R * scale + trf[:3, 3] = T.ravel() # doesn't need scaling + return trf + + +def estimate_focal(pts3d_i, pp=None): + if pp is None: + H, W, THREE = pts3d_i.shape + assert THREE == 3 + pp = torch.tensor((W/2, H/2), device=pts3d_i.device) + focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel() + return float(focal) + + +@cache +def pixel_grid(H, W): + return np.mgrid[:W, :H].T.astype(np.float32) + + +def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10): + # extract camera poses and focals with RANSAC-PnP + if msk.sum() < 4: + return None # we need at least 4 points for PnP + pts3d, msk = map(to_numpy, (pts3d, msk)) + + H, W, THREE = pts3d.shape + assert THREE == 3 + pixels = pixel_grid(H, W) + + if focal is None: + S = max(W, H) + tentative_focals = np.geomspace(S/2, S*3, 21) + else: + tentative_focals = [focal] + + if pp is None: + pp = (W/2, H/2) + else: + pp = to_numpy(pp) + + best = 0, + for focal in tentative_focals: + K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)]) + + success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None, + iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP) + if not success: + continue + + score = len(inliers) + if success and score > best[0]: + best = score, R, T, focal + + if not best[0]: + return None + + _, R, T, best_focal = best + R = cv2.Rodrigues(R)[0] # world to cam + R, T = map(torch.from_numpy, (R, T)) + return best_focal, inv(sRT_to_4x4(1, R, T, device)) # cam to world + + +def get_known_poses(self): + if self.has_im_poses: + known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses]) + known_poses = self.get_im_poses() + return known_poses_msk.sum(), known_poses_msk, known_poses + else: + return 0, None, None + + +def get_known_focals(self): + if self.has_im_poses: + known_focal_msk = self.get_known_focal_mask() + known_focals = self.get_focals() + return known_focal_msk.sum(), known_focal_msk, known_focals + else: + return 0, None, None + + +def align_multiple_poses(src_poses, target_poses): + N = len(src_poses) + assert src_poses.shape == target_poses.shape == (N, 4, 4) + + def center_and_z(poses): + eps = get_med_dist_between_poses(poses) / 100 + return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2])) + R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True) + return s, R, T diff --git a/monst3r/dust3r/cloud_opt/modular_optimizer.py b/monst3r/dust3r/cloud_opt/modular_optimizer.py new file mode 100644 index 0000000..c59b0c6 --- /dev/null +++ b/monst3r/dust3r/cloud_opt/modular_optimizer.py @@ -0,0 +1,147 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Slower implementation of the global alignment that allows to freeze partial poses/intrinsics +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn + +from dust3r.cloud_opt.base_opt import BasePCOptimizer +from dust3r.utils.geometry import geotrf +from dust3r.utils.device import to_cpu, to_numpy +from dust3r.utils.geometry import depthmap_to_pts3d +from dust3r.cloud_opt.optimizer import PointCloudOptimizer, tum_to_pose_matrix, ParameterStack, xy_grid + +class ModularPointCloudOptimizer (BasePCOptimizer): + """ Optimize a global scene, given a list of pairwise observations. + Unlike PointCloudOptimizer, you can fix parts of the optimization process (partial poses/intrinsics) + Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs): + super().__init__(*args, **kwargs) + self.has_im_poses = True # by definition of this class + self.focal_brake = focal_brake + + # adding thing to optimize + self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth) + self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses + default_focals = [self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes] + self.im_focals = nn.ParameterList(torch.FloatTensor([f, f] if fx_and_fy else [ + f]) for f in default_focals) # camera intrinsics + self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics + self.im_pp.requires_grad_(optimize_pp) + + def preset_pose(self, known_poses, pose_msk=None): # cam-to-world + if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2: + known_poses = [known_poses] + if known_poses.shape[-1] == 7: # xyz wxyz + known_poses = [tum_to_pose_matrix(pose) for pose in known_poses] + for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses): + if self.verbose: + print(f' (setting pose #{idx} = {pose[:3,3]})') + self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose), force=True)) + + # normalize scale if there's less than 1 known pose + n_known_poses = sum((p.requires_grad is False) for p in self.im_poses) + self.norm_pw_scale = (n_known_poses <= 1) + + def preset_intrinsics(self, known_intrinsics, msk=None): + if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2: + known_intrinsics = [known_intrinsics] + for K in known_intrinsics: + assert K.shape == (3, 3) + self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk) + self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk) + + def preset_focal(self, known_focals, msk=None): + for idx, focal in zip(self._get_msk_indices(msk), known_focals): + if self.verbose: + print(f' (setting focal #{idx} = {focal})') + self._no_grad(self._set_focal(idx, focal, force=True)) + + def preset_principal_point(self, known_pp, msk=None): + for idx, pp in zip(self._get_msk_indices(msk), known_pp): + if self.verbose: + print(f' (setting principal point #{idx} = {pp})') + self._no_grad(self._set_principal_point(idx, pp, force=True)) + + def _no_grad(self, tensor): + return tensor.requires_grad_(False) + + def _get_msk_indices(self, msk): + if msk is None: + return range(self.n_imgs) + elif isinstance(msk, int): + return [msk] + elif isinstance(msk, (tuple, list)): + return self._get_msk_indices(np.array(msk)) + elif msk.dtype in (bool, torch.bool, np.bool_): + assert len(msk) == self.n_imgs + return np.where(msk)[0] + elif np.issubdtype(msk.dtype, np.integer): + return msk + else: + raise ValueError(f'bad {msk=}') + + def _set_focal(self, idx, focal, force=False): + param = self.im_focals[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = self.focal_brake * np.log(focal) + return param + + def get_focals(self): + log_focals = torch.stack(list(self.im_focals), dim=0) + return (log_focals / self.focal_brake).exp() + + def _set_principal_point(self, idx, pp, force=False): + param = self.im_pp[idx] + H, W = self.imshapes[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10 + return param + + def get_principal_points(self): + return torch.stack([pp.new((W/2, H/2))+10*pp for pp, (H, W) in zip(self.im_pp, self.imshapes)]) + + def get_intrinsics(self): + K = torch.zeros((self.n_imgs, 3, 3), device=self.device) + focals = self.get_focals().view(self.n_imgs, -1) + K[:, 0, 0] = focals[:, 0] + K[:, 1, 1] = focals[:, -1] + K[:, :2, 2] = self.get_principal_points() + K[:, 2, 2] = 1 + return K + + def get_im_poses(self): # cam to world + cam2world = self._get_poses(torch.stack(list(self.im_poses))) + return cam2world + + def _set_depthmap(self, idx, depth, force=False): + param = self.im_depthmaps[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = depth.log().nan_to_num(neginf=0) + return param + + def get_depthmaps(self): + return [d.exp() for d in self.im_depthmaps] + + def depth_to_pts3d(self): + # Get depths and projection params if not provided + focals = self.get_focals() + pp = self.get_principal_points() + im_poses = self.get_im_poses() + depth = self.get_depthmaps() + + # convert focal to (1,2,H,W) constant field + def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i]) + # get pointmaps in camera frame + rel_ptmaps = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i+1])[0] for i in range(im_poses.shape[0])] + # project to world frame + return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)] + + def get_pts3d(self): + return self.depth_to_pts3d() diff --git a/monst3r/dust3r/cloud_opt/optimizer.py b/monst3r/dust3r/cloud_opt/optimizer.py new file mode 100644 index 0000000..c492359 --- /dev/null +++ b/monst3r/dust3r/cloud_opt/optimizer.py @@ -0,0 +1,782 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm +import contextlib + +from dust3r.cloud_opt.base_opt import BasePCOptimizer, edge_str +from dust3r.cloud_opt.pair_viewer import PairViewer +from dust3r.utils.geometry import xy_grid, geotrf, depthmap_to_pts3d +from dust3r.utils.device import to_cpu, to_numpy +from dust3r.utils.goem_opt import DepthBasedWarping, OccMask, WarpImage, depth_regularization_si_weighted, tum_to_pose_matrix +from third_party.raft import load_RAFT +from sam2.build_sam import build_sam2_video_predictor +sam2_checkpoint = "third_party/sam2/checkpoints/sam2.1_hiera_large.pt" +model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml" + +def smooth_L1_loss_fn(estimate, gt, mask, beta=1.0, per_pixel_thre=50.): + loss_raw_shape = F.smooth_l1_loss(estimate*mask, gt*mask, beta=beta, reduction='none') + if per_pixel_thre > 0: + per_pixel_mask = (loss_raw_shape < per_pixel_thre) * mask + else: + per_pixel_mask = mask + return torch.sum(loss_raw_shape * per_pixel_mask) / torch.sum(per_pixel_mask) + +def mse_loss_fn(estimate, gt, mask): + v = torch.sum((estimate*mask-gt*mask)**2) / torch.sum(mask) + return v # , v.item() + +class PointCloudOptimizer(BasePCOptimizer): + """ Optimize a global scene, given a list of pairwise observations. + Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, optimize_pp=False, focal_break=20, shared_focal=False, flow_loss_fn='smooth_l1', flow_loss_weight=0.0, + depth_regularize_weight=0.0, num_total_iter=300, temporal_smoothing_weight=0, translation_weight=0.1, flow_loss_start_epoch=0.15, flow_loss_thre=50, + sintel_ckpt=False, use_self_mask=False, pxl_thre=50, sam2_mask_refine=True, motion_mask_thre=0.35, batchify=True, **kwargs): + super().__init__(*args, **kwargs) + + self.has_im_poses = True # by definition of this class + self.focal_break = focal_break + self.num_total_iter = num_total_iter + self.temporal_smoothing_weight = temporal_smoothing_weight + self.translation_weight = translation_weight + self.flow_loss_flag = False + self.flow_loss_start_epoch = flow_loss_start_epoch + self.flow_loss_thre = flow_loss_thre + self.optimize_pp = optimize_pp + self.pxl_thre = pxl_thre + self.motion_mask_thre = motion_mask_thre + self.batchify = batchify + + # adding thing to optimize + self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth) + self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses + self.shared_focal = shared_focal + if self.shared_focal: + self.im_focals = nn.ParameterList(torch.FloatTensor( + [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes[:1]) # camera intrinsics + else: + self.im_focals = nn.ParameterList(torch.FloatTensor( + [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes) # camera intrinsics + self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics + self.im_pp.requires_grad_(optimize_pp) + + self.imshape = self.imshapes[0] + im_areas = [h*w for h, w in self.imshapes] + self.max_area = max(im_areas) + + # adding thing to optimize + if self.batchify: + self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area) #(num_imgs, H*W) + self.im_poses = ParameterStack(self.im_poses, is_param=True) + self.im_focals = ParameterStack(self.im_focals, is_param=True) + self.im_pp = ParameterStack(self.im_pp, is_param=True) + self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes])) + self.register_buffer('_grid', ParameterStack( + [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area)) + # pre-compute pixel weights + self.register_buffer('_weight_i', ParameterStack( + [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area)) + self.register_buffer('_weight_j', ParameterStack( + [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area)) + # precompute aa + self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area)) + self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area)) + + self.register_buffer('_ei', torch.tensor([i for i, j in self.edges])) + self.register_buffer('_ej', torch.tensor([j for i, j in self.edges])) + self.total_area_i = sum([im_areas[i] for i, j in self.edges]) + self.total_area_j = sum([im_areas[j] for i, j in self.edges]) + + self.depth_wrapper = DepthBasedWarping() + self.backward_warper = WarpImage() + self.depth_regularizer = depth_regularization_si_weighted + if flow_loss_fn == 'smooth_l1': + self.flow_loss_fn = smooth_L1_loss_fn + elif flow_loss_fn == 'mse': + self.low_loss_fn = mse_loss_fn + + self.flow_loss_weight = flow_loss_weight + self.depth_regularize_weight = depth_regularize_weight + if self.flow_loss_weight > 0: + self.flow_ij, self.flow_ji, self.flow_valid_mask_i, self.flow_valid_mask_j = self.get_flow(sintel_ckpt) # (num_pairs, 2, H, W) + if use_self_mask: self.get_motion_mask_from_pairs(*args) + # turn off the gradient for the flow + self.flow_ij.requires_grad_(False) + self.flow_ji.requires_grad_(False) + self.flow_valid_mask_i.requires_grad_(False) + self.flow_valid_mask_j.requires_grad_(False) + if sam2_mask_refine: + with torch.no_grad(): + self.refine_motion_mask_w_sam2() + else: + self.sam2_dynamic_masks = None + + def get_flow(self, sintel_ckpt=False): #TODO: test with gt flow + print('precomputing flow...') + device = 'cuda' if torch.cuda.is_available() else 'cpu' + get_valid_flow_mask = OccMask(th=3.0) + pair_imgs = [np.stack(self.imgs)[self._ei], np.stack(self.imgs)[self._ej]] + + flow_net = load_RAFT() if sintel_ckpt else load_RAFT("third_party/RAFT/models/Tartan-C-T-TSKH-spring540x960-M.pth") + flow_net = flow_net.to(device) + flow_net.eval() + + with torch.no_grad(): + chunk_size = 12 + flow_ij = [] + flow_ji = [] + num_pairs = len(pair_imgs[0]) + for i in tqdm(range(0, num_pairs, chunk_size)): + end_idx = min(i + chunk_size, num_pairs) + imgs_ij = [torch.tensor(pair_imgs[0][i:end_idx]).float().to(device), + torch.tensor(pair_imgs[1][i:end_idx]).float().to(device)] + flow_ij.append(flow_net(imgs_ij[0].permute(0, 3, 1, 2) * 255, + imgs_ij[1].permute(0, 3, 1, 2) * 255, + iters=20, test_mode=True)[1]) + flow_ji.append(flow_net(imgs_ij[1].permute(0, 3, 1, 2) * 255, + imgs_ij[0].permute(0, 3, 1, 2) * 255, + iters=20, test_mode=True)[1]) + + flow_ij = torch.cat(flow_ij, dim=0) + flow_ji = torch.cat(flow_ji, dim=0) + valid_mask_i = get_valid_flow_mask(flow_ij, flow_ji) + valid_mask_j = get_valid_flow_mask(flow_ji, flow_ij) + print('flow precomputed') + # delete the flow net + if flow_net is not None: del flow_net + return flow_ij, flow_ji, valid_mask_i, valid_mask_j + + def get_motion_mask_from_pairs(self, view1, view2, pred1, pred2): + assert self.is_symmetrized, 'only support symmetric case' + symmetry_pairs_idx = [(i, i+len(self.edges)//2) for i in range(len(self.edges)//2)] + intrinsics_i = [] + intrinsics_j = [] + R_i = [] + R_j = [] + T_i = [] + T_j = [] + depth_maps_i = [] + depth_maps_j = [] + for i, j in tqdm(symmetry_pairs_idx): + new_view1 = {} + new_view2 = {} + for key in view1.keys(): + if isinstance(view1[key], list): + new_view1[key] = [view1[key][i], view1[key][j]] + new_view2[key] = [view2[key][i], view2[key][j]] + elif isinstance(view1[key], torch.Tensor): + new_view1[key] = torch.stack([view1[key][i], view1[key][j]]) + new_view2[key] = torch.stack([view2[key][i], view2[key][j]]) + new_view1['idx'] = [0, 1] + new_view2['idx'] = [1, 0] + new_pred1 = {} + new_pred2 = {} + for key in pred1.keys(): + if isinstance(pred1[key], list): + new_pred1[key] = [pred1[key][i], pred1[key][j]] + elif isinstance(pred1[key], torch.Tensor): + new_pred1[key] = torch.stack([pred1[key][i], pred1[key][j]]) + for key in pred2.keys(): + if isinstance(pred2[key], list): + new_pred2[key] = [pred2[key][i], pred2[key][j]] + elif isinstance(pred2[key], torch.Tensor): + new_pred2[key] = torch.stack([pred2[key][i], pred2[key][j]]) + pair_viewer = PairViewer(new_view1, new_view2, new_pred1, new_pred2, verbose=False) + intrinsics_i.append(pair_viewer.get_intrinsics()[0]) + intrinsics_j.append(pair_viewer.get_intrinsics()[1]) + R_i.append(pair_viewer.get_im_poses()[0][:3, :3]) + R_j.append(pair_viewer.get_im_poses()[1][:3, :3]) + T_i.append(pair_viewer.get_im_poses()[0][:3, 3:]) + T_j.append(pair_viewer.get_im_poses()[1][:3, 3:]) + depth_maps_i.append(pair_viewer.get_depthmaps()[0]) + depth_maps_j.append(pair_viewer.get_depthmaps()[1]) + + self.intrinsics_i = torch.stack(intrinsics_i).to(self.flow_ij.device) + self.intrinsics_j = torch.stack(intrinsics_j).to(self.flow_ij.device) + self.R_i = torch.stack(R_i).to(self.flow_ij.device) + self.R_j = torch.stack(R_j).to(self.flow_ij.device) + self.T_i = torch.stack(T_i).to(self.flow_ij.device) + self.T_j = torch.stack(T_j).to(self.flow_ij.device) + self.depth_maps_i = torch.stack(depth_maps_i).unsqueeze(1).to(self.flow_ij.device) + self.depth_maps_j = torch.stack(depth_maps_j).unsqueeze(1).to(self.flow_ij.device) + + ego_flow_1_2, _ = self.depth_wrapper(self.R_i, self.T_i, self.R_j, self.T_j, 1 / (self.depth_maps_i + 1e-6), self.intrinsics_j, torch.linalg.inv(self.intrinsics_i)) + ego_flow_2_1, _ = self.depth_wrapper(self.R_j, self.T_j, self.R_i, self.T_i, 1 / (self.depth_maps_j + 1e-6), self.intrinsics_i, torch.linalg.inv(self.intrinsics_j)) + + err_map_i = torch.norm(ego_flow_1_2[:, :2, ...] - self.flow_ij[:len(symmetry_pairs_idx)], dim=1) + err_map_j = torch.norm(ego_flow_2_1[:, :2, ...] - self.flow_ji[:len(symmetry_pairs_idx)], dim=1) + # normalize the error map for each pair + err_map_i = (err_map_i - err_map_i.amin(dim=(1, 2), keepdim=True)) / (err_map_i.amax(dim=(1, 2), keepdim=True) - err_map_i.amin(dim=(1, 2), keepdim=True)) + err_map_j = (err_map_j - err_map_j.amin(dim=(1, 2), keepdim=True)) / (err_map_j.amax(dim=(1, 2), keepdim=True) - err_map_j.amin(dim=(1, 2), keepdim=True)) + self.dynamic_masks = [[] for _ in range(self.n_imgs)] + + for i, j in symmetry_pairs_idx: + i_idx = self._ei[i] + j_idx = self._ej[i] + self.dynamic_masks[i_idx].append(err_map_i[i]) + self.dynamic_masks[j_idx].append(err_map_j[i]) + + for i in range(self.n_imgs): + self.dynamic_masks[i] = torch.stack(self.dynamic_masks[i]).mean(dim=0) > self.motion_mask_thre + + def refine_motion_mask_w_sam2(self): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # Save previous TF32 settings + if device == 'cuda': + prev_allow_tf32 = torch.backends.cuda.matmul.allow_tf32 + prev_allow_cudnn_tf32 = torch.backends.cudnn.allow_tf32 + # Enable TF32 for Ampere GPUs + if torch.cuda.get_device_properties(0).major >= 8: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + try: + autocast_dtype = torch.bfloat16 if device == 'cuda' else torch.float32 + with torch.autocast(device_type=device, dtype=autocast_dtype): + predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device) + frame_tensors = torch.from_numpy(np.array((self.imgs))).permute(0, 3, 1, 2).to(device) + inference_state = predictor.init_state(video_path=frame_tensors) + mask_list = [self.dynamic_masks[i] for i in range(self.n_imgs)] + + ann_obj_id = 1 + self.sam2_dynamic_masks = [[] for _ in range(self.n_imgs)] + + # Process even frames + predictor.reset_state(inference_state) + for idx, mask in enumerate(mask_list): + if idx % 2 == 1: + _, out_obj_ids, out_mask_logits = predictor.add_new_mask( + inference_state, + frame_idx=idx, + obj_id=ann_obj_id, + mask=mask, + ) + video_segments = {} + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, start_frame_idx=0): + video_segments[out_frame_idx] = { + out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() + for i, out_obj_id in enumerate(out_obj_ids) + } + for out_frame_idx in range(self.n_imgs): + if out_frame_idx % 2 == 0: + self.sam2_dynamic_masks[out_frame_idx] = video_segments[out_frame_idx][ann_obj_id] + + # Process odd frames + predictor.reset_state(inference_state) + for idx, mask in enumerate(mask_list): + if idx % 2 == 0: + _, out_obj_ids, out_mask_logits = predictor.add_new_mask( + inference_state, + frame_idx=idx, + obj_id=ann_obj_id, + mask=mask, + ) + video_segments = {} + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, start_frame_idx=0): + video_segments[out_frame_idx] = { + out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() + for i, out_obj_id in enumerate(out_obj_ids) + } + for out_frame_idx in range(self.n_imgs): + if out_frame_idx % 2 == 1: + self.sam2_dynamic_masks[out_frame_idx] = video_segments[out_frame_idx][ann_obj_id] + + # Update dynamic masks + for i in range(self.n_imgs): + self.sam2_dynamic_masks[i] = torch.from_numpy(self.sam2_dynamic_masks[i][0]).to(device) + self.dynamic_masks[i] = self.dynamic_masks[i].to(device) + self.dynamic_masks[i] = self.dynamic_masks[i] | self.sam2_dynamic_masks[i] + + # Clean up + del predictor + finally: + # Restore previous TF32 settings + if device == 'cuda': + torch.backends.cuda.matmul.allow_tf32 = prev_allow_tf32 + torch.backends.cudnn.allow_tf32 = prev_allow_cudnn_tf32 + + + def _check_all_imgs_are_selected(self, msk): + self.msk = torch.from_numpy(np.array(msk, dtype=bool)).to(self.device) + assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!' + pass + + def preset_pose(self, known_poses, pose_msk=None, requires_grad=False): # cam-to-world + self._check_all_imgs_are_selected(pose_msk) + + if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2: + known_poses = [known_poses] + if known_poses.shape[-1] == 7: # xyz wxyz + known_poses = [tum_to_pose_matrix(pose) for pose in known_poses] + for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses): + if self.verbose: + print(f' (setting pose #{idx} = {pose[:3,3]})') + self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose))) + + # normalize scale if there's less than 1 known pose + n_known_poses = sum((p.requires_grad is False) for p in self.im_poses) + self.norm_pw_scale = (n_known_poses <= 1) + if len(known_poses) == self.n_imgs: + if requires_grad: + self.im_poses.requires_grad_(True) + else: + self.im_poses.requires_grad_(False) + self.norm_pw_scale = False + + def preset_intrinsics(self, known_intrinsics, msk=None): + if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2: + known_intrinsics = [known_intrinsics] + for K in known_intrinsics: + assert K.shape == (3, 3) + self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk) + if self.optimize_pp: + self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk) + + def preset_focal(self, known_focals, msk=None, requires_grad=False): + self._check_all_imgs_are_selected(msk) + + for idx, focal in zip(self._get_msk_indices(msk), known_focals): + if self.verbose: + print(f' (setting focal #{idx} = {focal})') + self._no_grad(self._set_focal(idx, focal)) + if len(known_focals) == self.n_imgs: + if requires_grad: + self.im_focals.requires_grad_(True) + else: + self.im_focals.requires_grad_(False) + + def preset_principal_point(self, known_pp, msk=None): + self._check_all_imgs_are_selected(msk) + + for idx, pp in zip(self._get_msk_indices(msk), known_pp): + if self.verbose: + print(f' (setting principal point #{idx} = {pp})') + self._no_grad(self._set_principal_point(idx, pp)) + + self.im_pp.requires_grad_(False) + + def _get_msk_indices(self, msk): + if msk is None: + return range(self.n_imgs) + elif isinstance(msk, int): + return [msk] + elif isinstance(msk, (tuple, list)): + return self._get_msk_indices(np.array(msk)) + elif msk.dtype in (bool, torch.bool, np.bool_): + assert len(msk) == self.n_imgs + return np.where(msk)[0] + elif np.issubdtype(msk.dtype, np.integer): + return msk + else: + raise ValueError(f'bad {msk=}') + + def _no_grad(self, tensor): + assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs' + + def _set_focal(self, idx, focal, force=False): + param = self.im_focals[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = self.focal_break * np.log(focal) + return param + + def get_focals(self): + if self.shared_focal: + log_focals = torch.stack([self.im_focals[0]] * self.n_imgs, dim=0) + else: + log_focals = torch.stack(list(self.im_focals), dim=0) + return (log_focals / self.focal_break).exp() + + def get_known_focal_mask(self): + return torch.tensor([not (p.requires_grad) for p in self.im_focals]) + + def _set_principal_point(self, idx, pp, force=False): + param = self.im_pp[idx] + H, W = self.imshapes[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10 + return param + + def get_principal_points_non_batch(self): + return torch.stack([pp.new((W/2, H/2))+10*pp for pp, (H, W) in zip(self.im_pp, self.imshapes)]) + + def get_principal_points_batch(self): + return self._pp + 10 * self.im_pp + + def get_principal_points(self): + if self.batchify: + return self.get_principal_points_batch() + else: + return self.get_principal_points_non_batch() + + def get_intrinsics(self): + K = torch.zeros((self.n_imgs, 3, 3), device=self.device) + focals = self.get_focals().flatten() + K[:, 0, 0] = K[:, 1, 1] = focals + K[:, :2, 2] = self.get_principal_points() + K[:, 2, 2] = 1 + return K + + def get_im_poses_batch(self): # cam to world + cam2world = self._get_poses(self.im_poses) + return cam2world + + def get_im_poses_non_batch(self): # cam to world + cam2world = self._get_poses(torch.stack(list(self.im_poses))) + return cam2world + + def get_im_poses(self): + if self.batchify: + return self.get_im_poses_batch() + else: + return self.get_im_poses_non_batch() + + def _set_depthmap_batch(self, idx, depth, force=False): + depth = _ravel_hw(depth, self.max_area) + + param = self.im_depthmaps[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = depth.log().nan_to_num(neginf=0) + return param + + def _set_depthmap_non_batch(self, idx, depth, force=False): + param = self.im_depthmaps[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = depth.log().nan_to_num(neginf=0) + return param + + def _set_depthmap(self, idx, depth, force=False): + if self.batchify: + return self._set_depthmap_batch(idx, depth, force) + else: + return self._set_depthmap_non_batch(idx, depth, force) + + def preset_depthmap(self, known_depthmaps, msk=None, requires_grad=False): + self._check_all_imgs_are_selected(msk) + + for idx, depth in zip(self._get_msk_indices(msk), known_depthmaps): + if self.verbose: + print(f' (setting depthmap #{idx})') + self._no_grad(self._set_depthmap(idx, depth)) + + if len(known_depthmaps) == self.n_imgs: + if requires_grad: + self.im_depthmaps.requires_grad_(True) + else: + self.im_depthmaps.requires_grad_(False) + + def _set_init_depthmap(self): + depth_maps = self.get_depthmaps(raw=True) + self.init_depthmap = [dm.detach().clone() for dm in depth_maps] + + def get_init_depthmaps(self, raw=False): + res = self.init_depthmap + if not raw: + res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def get_depthmaps_batch(self, raw=False): + res = self.im_depthmaps.exp() + if not raw: + res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def get_depthmaps_non_batch(self): + return [d.exp() for d in self.im_depthmaps] + + def get_depthmaps(self, raw=False): + if self.batchify: + return self.get_depthmaps_batch(raw) + else: + return self.get_depthmaps_non_batch() + + def depth_to_pts3d(self): + # Get depths and projection params if not provided + focals = self.get_focals() + pp = self.get_principal_points() + im_poses = self.get_im_poses() + depth = self.get_depthmaps(raw=True) + + # get pointmaps in camera frame + rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp) + # project to world frame + return geotrf(im_poses, rel_ptmaps) + + def depth_to_pts3d_partial(self): + # Get depths and projection params if not provided + focals = self.get_focals() + pp = self.get_principal_points() + im_poses = self.get_im_poses() + depth = self.get_depthmaps() + + # convert focal to (1,2,H,W) constant field + def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i]) + # get pointmaps in camera frame + rel_ptmaps = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i+1])[0] for i in range(im_poses.shape[0])] + # project to world frame + return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)] + + def get_pts3d_batch(self, raw=False, **kwargs): + res = self.depth_to_pts3d() + if not raw: + res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def get_pts3d(self, raw=False, **kwargs): + if self.batchify: + return self.get_pts3d_batch(raw, **kwargs) + else: + return self.depth_to_pts3d_partial() + + def forward_batchify(self, epoch=9999): + pw_poses = self.get_pw_poses() # cam-to-world + + pw_adapt = self.get_adaptors().unsqueeze(1) + proj_pts3d = self.get_pts3d(raw=True) + + # rotate pairwise prediction according to pw_poses + aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i) + aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j) + + # compute the less + li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i + lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j + + # camera temporal loss + if self.temporal_smoothing_weight > 0: + temporal_smoothing_loss = self.relative_pose_loss(self.get_im_poses()[:-1], self.get_im_poses()[1:]).sum() + else: + temporal_smoothing_loss = 0 + + if self.flow_loss_weight > 0 and epoch >= self.num_total_iter * self.flow_loss_start_epoch: # enable flow loss after certain epoch + R_all, T_all = self.get_im_poses()[:,:3].split([3, 1], dim=-1) + R1, T1 = R_all[self._ei], T_all[self._ei] + R2, T2 = R_all[self._ej], T_all[self._ej] + K_all = self.get_intrinsics() + inv_K_all = torch.linalg.inv(K_all) + K_1, inv_K_1 = K_all[self._ei], inv_K_all[self._ei] + K_2, inv_K_2 = K_all[self._ej], inv_K_all[self._ej] + depth_all = torch.stack(self.get_depthmaps(raw=False)).unsqueeze(1) + depth1, depth2 = depth_all[self._ei], depth_all[self._ej] + disp_1, disp_2 = 1 / (depth1 + 1e-6), 1 / (depth2 + 1e-6) + ego_flow_1_2, _ = self.depth_wrapper(R1, T1, R2, T2, disp_1, K_2, inv_K_1) + ego_flow_2_1, _ = self.depth_wrapper(R2, T2, R1, T1, disp_2, K_1, inv_K_2) + dynamic_masks_all = torch.stack(self.dynamic_masks).to(self.device).unsqueeze(1) + dynamic_mask1, dynamic_mask2 = dynamic_masks_all[self._ei], dynamic_masks_all[self._ej] + + flow_loss_i = self.flow_loss_fn(ego_flow_1_2[:, :2, ...], self.flow_ij, ~dynamic_mask1, per_pixel_thre=self.pxl_thre) + flow_loss_j = self.flow_loss_fn(ego_flow_2_1[:, :2, ...], self.flow_ji, ~dynamic_mask2, per_pixel_thre=self.pxl_thre) + flow_loss = flow_loss_i + flow_loss_j + print(f'flow loss: {flow_loss.item()}') + if flow_loss.item() > self.flow_loss_thre and self.flow_loss_thre > 0: + flow_loss = 0 + self.flow_loss_flag = True + else: + flow_loss = 0 + + if self.depth_regularize_weight > 0: + init_depthmaps = torch.stack(self.get_init_depthmaps(raw=False)).unsqueeze(1) + depthmaps = torch.stack(self.get_depthmaps(raw=False)).unsqueeze(1) + dynamic_masks_all = torch.stack(self.dynamic_masks).to(self.device).unsqueeze(1) + depth_prior_loss = self.depth_regularizer(depthmaps, init_depthmaps, dynamic_masks_all) + else: + depth_prior_loss = 0 + + loss = (li + lj) * 1 + self.temporal_smoothing_weight * temporal_smoothing_loss + \ + self.flow_loss_weight * flow_loss + self.depth_regularize_weight * depth_prior_loss + + return loss + + def forward_non_batchify(self, epoch=9999): + + # --(1) Perform the original pairwise 3D consistency loss (pairwise 3D consistency)-- + pw_poses = self.get_pw_poses() # pair-wise poses (or adaptive poses) + pw_adapt = self.get_adaptors() + proj_pts3d = self.get_pts3d() # 3D point clouds for each image + weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()} + weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()} + + loss = 0.0 + for e, (i, j) in enumerate(self.edges): + i_j = edge_str(i, j) + # Transform the pairwise predictions to the world coordinate system + aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j]) + aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j]) + # Compute the distance loss between the projected point clouds and the predictions + li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean() + lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean() + loss += (li + lj) + + # Average the loss + loss /= self.n_edges + + # --(2) Add temporal smoothing constraint between adjacent frames (temporal smoothing)-- + temporal_smoothing_loss = 0.0 + if self.temporal_smoothing_weight > 0: + # Get the global poses (4x4) for all images + im_poses = self.get_im_poses() # shape: (n_imgs, 4, 4) + # Stack the relative poses between adjacent frames and use the existing relative_pose_loss function + rel_RT1, rel_RT2 = [], [] + for idx in range(self.n_imgs - 1): + rel_RT1.append(im_poses[idx]) + rel_RT2.append(im_poses[idx + 1]) + if len(rel_RT1) > 0: + rel_RT1 = torch.stack(rel_RT1, dim=0) # shape: (n_imgs-1, 4, 4) + rel_RT2 = torch.stack(rel_RT2, dim=0) + # Compute the pose difference between adjacent frames + temporal_smoothing_loss = self.relative_pose_loss(rel_RT1, rel_RT2).sum() + loss += self.temporal_smoothing_weight * temporal_smoothing_loss + + # --(3) Add flow constraint (flow_loss), similar to forward_batchify-- + flow_loss = 0.0 + if self.flow_loss_weight > 0 and epoch >= self.num_total_iter * self.flow_loss_start_epoch: + # Iterate through each pair of images and compute the depth map and flow comparison + im_poses = self.get_im_poses() # (n_imgs, 4, 4) + K_all = self.get_intrinsics() # (n_imgs, 3, 3) + inv_K_all = torch.linalg.inv(K_all) + depthmaps = self.get_depthmaps(raw=False) # list of depth maps (H, W) + + for e, (i, j) in enumerate(self.edges): + # Get the rotation, translation, and intrinsics for the two frames + R1 = im_poses[i][:3, :3].unsqueeze(0) # shape: (1, 3, 3) + T1 = im_poses[i][:3, 3].unsqueeze(-1).unsqueeze(0) # (1, 3, 1) + R2 = im_poses[j][:3, :3].unsqueeze(0) + T2 = im_poses[j][:3, 3].unsqueeze(-1).unsqueeze(0) + K1 = K_all[i].unsqueeze(0) # (1, 3, 3) + K2 = K_all[j].unsqueeze(0) + inv_K1 = inv_K_all[i].unsqueeze(0) + inv_K2 = inv_K_all[j].unsqueeze(0) + + # Construct disparity: disp = 1/depth + depth1 = depthmaps[i].unsqueeze(0).unsqueeze(1) # (1, 1, H, W) + depth2 = depthmaps[j].unsqueeze(0).unsqueeze(1) + disp_1 = 1.0 / (depth1 + 1e-6) + disp_2 = 1.0 / (depth2 + 1e-6) + + # Compute "ego-motion flow" by projecting using DepthBasedWarping + # Note that DepthBasedWarping expects batch dimension, so add unsqueeze(0) + ego_flow_1_2, _ = self.depth_wrapper(R1, T1, R2, T2, disp_1, K2, inv_K1) + ego_flow_2_1, _ = self.depth_wrapper(R2, T2, R1, T1, disp_2, K1, inv_K2) + + # Get the corresponding dynamic region masks (if any) + dynamic_mask_i = self.dynamic_masks[i] # shape: (H, W) + dynamic_mask_j = self.dynamic_masks[j] + + # When computing flow loss, exclude or ignore dynamic regions + flow_loss_i = self.flow_loss_fn( + ego_flow_1_2[0, :2, ...], # shape: (2, H, W) + self.flow_ij[e], # shape: (2, H, W), i->j + ~dynamic_mask_i, # mask: True = keep, False = ignore + per_pixel_thre=self.pxl_thre + ) + flow_loss_j = self.flow_loss_fn( + ego_flow_2_1[0, :2, ...], + self.flow_ji[e], # j->i + ~dynamic_mask_j, + per_pixel_thre=self.pxl_thre + ) + flow_loss += (flow_loss_i + flow_loss_j) + + # Optional: handle cases where the flow loss is too large (e.g., early stop) + # divide by the number of edges + flow_loss /= self.n_edges + print(f'flow loss: {flow_loss.item()}') + if flow_loss.item() > self.flow_loss_thre and self.flow_loss_thre > 0: + flow_loss = 0.0 + + loss += self.flow_loss_weight * flow_loss + + # --(4) Add depth regularization (depth_prior_loss) to constrain the initial depth-- + if self.depth_regularize_weight > 0: + init_depthmaps = self.get_init_depthmaps(raw=False) # initial depth maps + current_depthmaps = self.get_depthmaps(raw=False) # current optimized depth maps + depth_prior_loss = 0.0 + for i in range(self.n_imgs): + # Apply constraints on static regions (ignore dynamic regions) + # Make sure the shape has the batch dimension (B,1,H,W) + depth_prior_loss += self.depth_regularizer( + current_depthmaps[i].unsqueeze(0).unsqueeze(1), + init_depthmaps[i].unsqueeze(0).unsqueeze(1), + self.dynamic_masks[i].unsqueeze(0).unsqueeze(1) + ) + loss += self.depth_regularize_weight * depth_prior_loss + + return loss + + def forward(self, epoch=9999): + if self.batchify: + return self.forward_batchify(epoch) + else: + return self.forward_non_batchify(epoch) + + def relative_pose_loss(self, RT1, RT2): + relative_RT = torch.matmul(torch.inverse(RT1), RT2) + rotation_diff = relative_RT[:, :3, :3] + translation_diff = relative_RT[:, :3, 3] + + # Frobenius norm for rotation difference + rotation_loss = torch.norm(rotation_diff - (torch.eye(3, device=RT1.device)), dim=(1, 2)) + + # L2 norm for translation difference + translation_loss = torch.norm(translation_diff, dim=1) + + # Combined loss (one can weigh these differently if needed) + pose_loss = rotation_loss + translation_loss * self.translation_weight + return pose_loss + +def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp): + pp = pp.unsqueeze(1) + focal = focal.unsqueeze(1) + assert focal.shape == (len(depth), 1, 1) + assert pp.shape == (len(depth), 1, 2) + assert pixel_grid.shape == depth.shape + (2,) + depth = depth.unsqueeze(-1) + return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1) + + +def ParameterStack(params, keys=None, is_param=None, fill=0): + if keys is not None: + params = [params[k] for k in keys] + + if fill > 0: + params = [_ravel_hw(p, fill) for p in params] + + requires_grad = params[0].requires_grad + assert all(p.requires_grad == requires_grad for p in params) + + params = torch.stack(list(params)).float().detach() + if is_param or requires_grad: + params = nn.Parameter(params) + params.requires_grad_(requires_grad) + return params + + +def _ravel_hw(tensor, fill=0): + # ravel H,W + tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:]) + + if len(tensor) < fill: + tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:]))) + return tensor + + +def acceptable_focal_range(H, W, minf=0.5, maxf=3.5): + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515 + return minf*focal_base, maxf*focal_base + + +def apply_mask(img, msk): + img = img.copy() + img[msk] = 0 + return img + +def ordered_ratio(disp_a, disp_b, mask=None): + ratio_a = torch.maximum(disp_a, disp_b) / \ + (torch.minimum(disp_a, disp_b)+1e-5) + if mask is not None: + ratio_a = ratio_a[mask] + return ratio_a - 1 \ No newline at end of file diff --git a/monst3r/dust3r/cloud_opt/pair_viewer.py b/monst3r/dust3r/cloud_opt/pair_viewer.py new file mode 100644 index 0000000..7307da5 --- /dev/null +++ b/monst3r/dust3r/cloud_opt/pair_viewer.py @@ -0,0 +1,133 @@ +# -------------------------------------------------------- +# Dummy optimizer for visualizing pairs +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn +import cv2 + +from dust3r.cloud_opt.base_opt import BasePCOptimizer +from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates +from dust3r.cloud_opt.commons import edge_str +from dust3r.post_process import estimate_focal_knowing_depth + + +class PairViewer (BasePCOptimizer): + """ + This a Dummy Optimizer. + To use only when the goal is to visualize the results for a pair of images (with is_symmetrized) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.is_symmetrized and self.n_edges == 2 + self.has_im_poses = True + + # compute all parameters directly from raw input + self.focals = [] + self.pp = [] + rel_poses = [] + confs = [] + for i in range(self.n_imgs): + conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean()) + if self.verbose: + print(f' - {conf=:.3} for edge {i}-{1-i}') + confs.append(conf) + + H, W = self.imshapes[i] + pts3d = self.pred_i[edge_str(i, 1-i)] + pp = torch.tensor((W/2, H/2)) + focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld')) + self.focals.append(focal) + self.pp.append(pp) + + # estimate the pose of pts1 in image 2 + pixels = np.mgrid[:W, :H].T.astype(np.float32) + pts3d = self.pred_j[edge_str(1-i, i)].numpy() + assert pts3d.shape[:2] == (H, W) + msk = self.get_masks()[i].numpy() + K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)]) + + try: + res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None, + iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP) + success, R, T, inliers = res + assert success + + R = cv2.Rodrigues(R)[0] # world to cam + pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world + except: + pose = np.eye(4) + rel_poses.append(torch.from_numpy(pose.astype(np.float32))) + + # let's use the pair with the most confidence + if confs[0] > confs[1]: + # ptcloud is expressed in camera1 + self.im_poses = [torch.eye(4), rel_poses[1]] # I, cam2-to-cam1 + self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]] + else: + # ptcloud is expressed in camera2 + self.im_poses = [rel_poses[0], torch.eye(4)] # I, cam1-to-cam2 + self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]] + + self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False) + self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False) + self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False) + self.depth = nn.ParameterList(self.depth) + for p in self.parameters(): + p.requires_grad = False + + def _set_depthmap(self, idx, depth, force=False): + if self.verbose: + print('_set_depthmap is ignored in PairViewer') + return + + def get_depthmaps(self, raw=False): + depth = [d.to(self.device) for d in self.depth] + return depth + + def _set_focal(self, idx, focal, force=False): + self.focals[idx] = focal + + def get_focals(self): + return self.focals + + def get_known_focal_mask(self): + return torch.tensor([not (p.requires_grad) for p in self.focals]) + + def get_principal_points(self): + return self.pp + + def get_intrinsics(self): + focals = self.get_focals() + pps = self.get_principal_points() + K = torch.zeros((len(focals), 3, 3), device=self.device) + for i in range(len(focals)): + K[i, 0, 0] = K[i, 1, 1] = focals[i] + K[i, :2, 2] = pps[i] + K[i, 2, 2] = 1 + return K + + def get_im_poses(self): + return self.im_poses + + def depth_to_pts3d(self, raw_pts=False): + pts3d = [] + if raw_pts: + im_poses = self.get_im_poses() + if im_poses[0].sum() == 4: + pts3d.append(self.pred_i['0_1']) + pts3d.append(self.pred_j['0_1']) + else: + pts3d.append(self.pred_j['1_0']) + pts3d.append(self.pred_i['1_0']) + else: + for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()): + pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(), + intrinsics.cpu().numpy(), + im_pose.cpu().numpy()) + pts3d.append(torch.from_numpy(pts).to(device=self.device)) + return pts3d + + def forward(self): + return float('nan') diff --git a/monst3r/dust3r/datasets/__init__.py b/monst3r/dust3r/datasets/__init__.py new file mode 100644 index 0000000..f1cf539 --- /dev/null +++ b/monst3r/dust3r/datasets/__init__.py @@ -0,0 +1,56 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +from .utils.transforms import * +from .base.batched_sampler import BatchedRandomSampler # noqa +from .arkitscenes import ARKitScenes # noqa +from .blendedmvs import BlendedMVS # noqa +from .co3d import Co3d # noqa +from .habitat import Habitat # noqa +from .megadepth import MegaDepth # noqa +from .scannetpp import ScanNetpp # noqa +from .staticthings3d import StaticThings3D # noqa +from .waymo import Waymo # noqa +from .wildrgbd import WildRGBD # noqa +from .pointodyssey import PointOdysseyDUSt3R # noqa +from .sintel import SintelDUSt3R # noqa +from .tartanair import TarTanAirDUSt3R # noqa +from .spring_dataset import SpringDUSt3R # noqa +from .dynamic_replica import DynamicReplicaDUSt3R # noqa + +def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True): + import torch + from croco.utils.misc import get_world_size, get_rank + + # pytorch dataset + if isinstance(dataset, str): + dataset = eval(dataset) + # dataset: "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter)" + # eval(dataset) returns Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter) + + world_size = get_world_size() + rank = get_rank() + + try: + sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size, + rank=rank, drop_last=drop_last) + except (AttributeError, NotImplementedError): + # not avail for this dataset + if torch.distributed.is_initialized(): + sampler = torch.utils.data.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last + ) + elif shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_mem, + drop_last=drop_last, + ) + + return data_loader diff --git a/monst3r/dust3r/datasets/arkitscenes.py b/monst3r/dust3r/datasets/arkitscenes.py new file mode 100644 index 0000000..4fad51a --- /dev/null +++ b/monst3r/dust3r/datasets/arkitscenes.py @@ -0,0 +1,102 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed arkitscenes +# dataset at https://github.com/apple/ARKitScenes - Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license +# See datasets_preprocess/preprocess_arkitscenes.py +# -------------------------------------------------------- +import os.path as osp +import cv2 +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class ARKitScenes(BaseStereoViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + if split == "train": + self.split = "Training" + elif split == "test": + self.split = "Test" + else: + raise ValueError("") + + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + with np.load(osp.join(self.ROOT, split, 'all_metadata.npz')) as data: + self.scenes = data['scenes'] + self.sceneids = data['sceneids'] + self.images = data['images'] + self.intrinsics = data['intrinsics'].astype(np.float32) + self.trajectories = data['trajectories'].astype(np.float32) + self.pairs = data['pairs'][:, :2].astype(int) + + def __len__(self): + return len(self.pairs) + + def _get_views(self, idx, resolution, rng): + + image_idx1, image_idx2 = self.pairs[idx] + + views = [] + for view_idx in [image_idx1, image_idx2]: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(scene_dir, 'vga_wide', basename.replace('.png', '.jpg'))) + # Load depthmap + depthmap = imread_cv2(osp.join(scene_dir, 'lowres_depth', basename), cv2.IMREAD_UNCHANGED) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) + + views.append(dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset='arkitscenes', + label=self.scenes[scene_id] + '_' + basename, + instance=f'{str(idx)}_{str(view_idx)}', + )) + + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = ARKitScenes(split='train', ROOT="data/arkitscenes_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/monst3r/dust3r/datasets/base/__init__.py b/monst3r/dust3r/datasets/base/__init__.py new file mode 100644 index 0000000..a326921 --- /dev/null +++ b/monst3r/dust3r/datasets/base/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/monst3r/dust3r/datasets/base/base_stereo_view_dataset.py b/monst3r/dust3r/datasets/base/base_stereo_view_dataset.py new file mode 100644 index 0000000..ff9489f --- /dev/null +++ b/monst3r/dust3r/datasets/base/base_stereo_view_dataset.py @@ -0,0 +1,229 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# base class for implementing datasets +# -------------------------------------------------------- +import PIL +import numpy as np +import torch + +from dust3r.datasets.base.easy_dataset import EasyDataset +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates +import dust3r.datasets.utils.cropping as cropping + + +class BaseStereoViewDataset (EasyDataset): + """ Define all basic options. + + Usage: + class MyDataset (BaseStereoViewDataset): + def _get_views(self, idx, rng): + # overload here + views = [] + views.append(dict(img=, ...)) + return views + """ + + def __init__(self, *, # only keyword arguments + split=None, + resolution=None, # square_size or (width, height) or list of [(width,height), ...] + transform=ImgNorm, + aug_crop=False, + aug_focal=False, + z_far=0, + seed=None): + self.num_views = 2 + self.split = split + self._set_resolutions(resolution) + + self.transform = transform + if isinstance(transform, str): + transform = eval(transform) + + self.aug_crop = aug_crop + self.aug_focal = aug_focal + self.seed = seed + self.z_far = z_far + + def __len__(self): + return len(self.scenes) + + def get_stats(self): + return f"{len(self)} pairs" + + def __repr__(self): + resolutions_str = '['+';'.join(f'{w}x{h}' for w, h in self._resolutions)+']' + return f"""{type(self).__name__}({self.get_stats()}, + {self.split=}, + {self.seed=}, + resolutions={resolutions_str}, + {self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '') + + def _get_views(self, idx, resolution, rng): + raise NotImplementedError() + + def __getitem__(self, idx): + if isinstance(idx, tuple): + # the idx is specifying the aspect-ratio + idx, ar_idx = idx + else: + assert len(self._resolutions) == 1 + ar_idx = 0 + + # set-up the rng + if self.seed: # reseed for each __getitem__ + self._rng = np.random.default_rng(seed=self.seed + idx) + elif not hasattr(self, '_rng'): + seed = torch.initial_seed() # this is different for each dataloader process + self._rng = np.random.default_rng(seed=seed) + + # over-loaded code + resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) + views = self._get_views(idx, resolution, self._rng) + assert len(views) == self.num_views + + # check data-types + for v, view in enumerate(views): + assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" + view['idx'] = (idx, ar_idx, v) + + # encode the image + width, height = view['img'].size + view['true_shape'] = np.int32((height, width)) + view['img'] = self.transform(view['img']) + + assert 'camera_intrinsics' in view + if 'camera_pose' not in view: + view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32) + else: + assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}' + assert 'pts3d' not in view + assert 'valid_mask' not in view + assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}' + view['z_far'] = self.z_far + pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view) + + view['pts3d'] = pts3d + view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1) + + # check all datatypes + for key, val in view.items(): + res, err_msg = is_good_type(key, val) + assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" + K = view['camera_intrinsics'] + + # last thing done! + for view in views: + # transpose to make sure all views are the same size + transpose_to_landscape(view) + # this allows to check whether the RNG is is the same state each time + view['rng'] = int.from_bytes(self._rng.bytes(4), 'big') + return views + + def _set_resolutions(self, resolutions): + assert resolutions is not None, 'undefined resolution' + + if not isinstance(resolutions, list): + resolutions = [resolutions] + + self._resolutions = [] + for resolution in resolutions: + if isinstance(resolution, int): + width = height = resolution + else: + width, height = resolution + assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int' + assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int' + assert width >= height + self._resolutions.append((width, height)) + + def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None): + """ This function: + - first downsizes the image with LANCZOS inteprolation, + which is better than bilinear interpolation in + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # downscale with lanczos interpolation so that image.size == resolution + # cropping centered on the principal point + W, H = image.size + cx, cy = intrinsics[:2, 2].round().astype(int) + min_margin_x = min(cx, W-cx) + min_margin_y = min(cy, H-cy) + assert min_margin_x > W/5, f'Bad principal point in view={info}' + assert min_margin_y > H/5, f'Bad principal point in view={info}' + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox) + + # transpose the resolution if necessary + W, H = image.size # new size + assert resolution[0] >= resolution[1] + if H > 1.1*W: + # image is portrait mode + resolution = resolution[::-1] + elif 0.9 < H/W < 1.1 and resolution[0] != resolution[1]: + # image is square, so we chose (portrait, landscape) randomly + if rng.integers(2): + resolution = resolution[::-1] + + # high-quality Lanczos down-scaling + target_resolution = np.array(resolution) + if self.aug_focal: + crop_scale = self.aug_focal + (1.0 - self.aug_focal) * np.random.beta(0.5, 0.5) # beta distribution, bi-modal + image, depthmap, intrinsics = cropping.center_crop_image_depthmap(image, depthmap, intrinsics, crop_scale) + + if self.aug_crop > 1: + target_resolution += rng.integers(0, self.aug_crop) + image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution) # slightly scale the image a bit larger than the target resolution + + # actual cropping (if necessary) with bilinear interpolation + intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5) + crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution) + image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox) + + return image, depthmap, intrinsics2 + + +def is_good_type(key, v): + """ returns (is_good, err_msg) + """ + if isinstance(v, (str, int, tuple)): + return True, None + if v.dtype not in (torch.bool, np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): + return False, f"bad {v.dtype=}" + return True, None + + +def view_name(view, batch_index=None): + def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x + db = sel(view['dataset']) + label = sel(view['label']) + instance = sel(view['instance']) + return f"{db}/{label}/{instance}" + + +def transpose_to_landscape(view): + height, width = view['true_shape'] + + if width < height: + # rectify portrait to landscape + assert view['img'].shape == (3, height, width) + view['img'] = view['img'].swapaxes(1, 2) + + assert view['valid_mask'].shape == (height, width) + view['valid_mask'] = view['valid_mask'].swapaxes(0, 1) + + assert view['depthmap'].shape == (height, width) + view['depthmap'] = view['depthmap'].swapaxes(0, 1) + + assert view['pts3d'].shape == (height, width, 3) + view['pts3d'] = view['pts3d'].swapaxes(0, 1) + + # transpose x and y pixels + view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]] diff --git a/monst3r/dust3r/datasets/base/batched_sampler.py b/monst3r/dust3r/datasets/base/batched_sampler.py new file mode 100644 index 0000000..85f58a6 --- /dev/null +++ b/monst3r/dust3r/datasets/base/batched_sampler.py @@ -0,0 +1,74 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Random sampling under a constraint +# -------------------------------------------------------- +import numpy as np +import torch + + +class BatchedRandomSampler: + """ Random sampling under a constraint: each sample in the batch has the same feature, + which is chosen randomly from a known pool of 'features' for each batch. + + For instance, the 'feature' could be the image aspect-ratio. + + The index returned is a tuple (sample_idx, feat_idx). + This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. + """ + + def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True): + self.batch_size = batch_size + self.pool_size = pool_size + + self.len_dataset = N = len(dataset) + self.total_size = round_by(N, batch_size*world_size) if drop_last else N + assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode' + + # distributed sampler + self.world_size = world_size + self.rank = rank + self.epoch = None + + def __len__(self): + return self.total_size // self.world_size + + def set_epoch(self, epoch): + self.epoch = epoch + + def __iter__(self): + # prepare RNG + if self.epoch is None: + assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used' + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.epoch + 777 + rng = np.random.default_rng(seed=seed) + + # random indices (will restart from 0 if not drop_last) + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + + # random feat_idxs (same across each batch) + n_batches = (self.total_size+self.batch_size-1) // self.batch_size + feat_idxs = rng.integers(self.pool_size, size=n_batches) + feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) + feat_idxs = feat_idxs.ravel()[:self.total_size] + + # put them together + idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2) + + # Distributed sampler: we select a subset of batches + # make sure the slice for each node is aligned with batch_size + size_per_proc = self.batch_size * ((self.total_size + self.world_size * + self.batch_size-1) // (self.world_size * self.batch_size)) + idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc] + + yield from (tuple(idx) for idx in idxs) + + +def round_by(total, multiple, up=False): + if up: + total = total + multiple-1 + return (total//multiple) * multiple diff --git a/monst3r/dust3r/datasets/base/easy_dataset.py b/monst3r/dust3r/datasets/base/easy_dataset.py new file mode 100644 index 0000000..4939a88 --- /dev/null +++ b/monst3r/dust3r/datasets/base/easy_dataset.py @@ -0,0 +1,157 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# A dataset base class that you can easily resize and combine. +# -------------------------------------------------------- +import numpy as np +from dust3r.datasets.base.batched_sampler import BatchedRandomSampler + + +class EasyDataset: + """ a dataset that you can easily resize and combine. + Examples: + --------- + 2 * dataset ==> duplicate each element 2x + + 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary) + + dataset1 + dataset2 ==> concatenate datasets + """ + + def __add__(self, other): + return CatDataset([self, other]) + + def __rmul__(self, factor): + return MulDataset(factor, self) + + def __rmatmul__(self, factor): + return ResizedDataset(factor, self) + + def set_epoch(self, epoch): + pass # nothing to do by default + + def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True): + if not (shuffle): + raise NotImplementedError() # cannot deal yet + num_of_aspect_ratios = len(self._resolutions) + return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last) + + +class MulDataset (EasyDataset): + """ Artifically augmenting the size of a dataset. + """ + multiplicator: int + + def __init__(self, multiplicator, dataset): + assert isinstance(multiplicator, int) and multiplicator > 0 + self.multiplicator = multiplicator + self.dataset = dataset + + def __len__(self): + return self.multiplicator * len(self.dataset) + + def __repr__(self): + return f'{self.multiplicator}*{repr(self.dataset)}' + + def __getitem__(self, idx): + if isinstance(idx, tuple): + idx, other = idx + return self.dataset[idx // self.multiplicator, other] + else: + return self.dataset[idx // self.multiplicator] + + @property + def _resolutions(self): + return self.dataset._resolutions + + +class ResizedDataset (EasyDataset): + """ Artifically changing the size of a dataset. + """ + new_size: int + + def __init__(self, new_size, dataset): + assert isinstance(new_size, int) and new_size > 0 + self.new_size = new_size + self.dataset = dataset + + def __len__(self): + return self.new_size + + def __repr__(self): + size_str = str(self.new_size) + for i in range((len(size_str)-1) // 3): + sep = -4*i-3 + size_str = size_str[:sep] + '_' + size_str[sep:] + return f'{size_str} @ {repr(self.dataset)}' + + def set_epoch(self, epoch): + # this random shuffle only depends on the epoch + rng = np.random.default_rng(seed=epoch+777) + + # shuffle all indices + perm = rng.permutation(len(self.dataset)) + + # rotary extension until target size is met + shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset))) + self._idxs_mapping = shuffled_idxs[:self.new_size] + + assert len(self._idxs_mapping) == self.new_size + + def __getitem__(self, idx): + assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()' + if isinstance(idx, tuple): + idx, other = idx + return self.dataset[self._idxs_mapping[idx], other] + else: + return self.dataset[self._idxs_mapping[idx]] + + @property + def _resolutions(self): + return self.dataset._resolutions + + +class CatDataset (EasyDataset): + """ Concatenation of several datasets + """ + + def __init__(self, datasets): + for dataset in datasets: + assert isinstance(dataset, EasyDataset) + self.datasets = datasets + self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets]) + + def __len__(self): + return self._cum_sizes[-1] + + def __repr__(self): + # remove uselessly long transform + return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets) + + def set_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_epoch(epoch) + + def __getitem__(self, idx): + other = None + if isinstance(idx, tuple): + idx, other = idx + + if not (0 <= idx < len(self)): + raise IndexError() + + db_idx = np.searchsorted(self._cum_sizes, idx, 'right') + dataset = self.datasets[db_idx] + new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) + + if other is not None: + new_idx = (new_idx, other) + return dataset[new_idx] + + @property + def _resolutions(self): + resolutions = self.datasets[0]._resolutions + for dataset in self.datasets[1:]: + assert tuple(dataset._resolutions) == tuple(resolutions) + return resolutions diff --git a/monst3r/dust3r/datasets/blendedmvs.py b/monst3r/dust3r/datasets/blendedmvs.py new file mode 100644 index 0000000..93e68c2 --- /dev/null +++ b/monst3r/dust3r/datasets/blendedmvs.py @@ -0,0 +1,104 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed BlendedMVS +# dataset at https://github.com/YoYo000/BlendedMVS +# See datasets_preprocess/preprocess_blendedmvs.py +# -------------------------------------------------------- +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class BlendedMVS (BaseStereoViewDataset): + """ Dataset of outdoor street scenes, 5 images each time + """ + + def __init__(self, *args, ROOT, split=None, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self._load_data(split) + + def _load_data(self, split): + pairs = np.load(osp.join(self.ROOT, 'blendedmvs_pairs.npy')) + if split is None: + selection = slice(None) + if split == 'train': + # select 90% of all scenes + selection = (pairs['seq_low'] % 10) > 0 + if split == 'val': + # select 10% of all scenes + selection = (pairs['seq_low'] % 10) == 0 + self.pairs = pairs[selection] + + # list of all scenes + self.scenes = np.unique(self.pairs['seq_low']) # low is unique enough + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs from {len(self.scenes)} scenes' + + def _get_views(self, pair_idx, resolution, rng): + seqh, seql, img1, img2, score = self.pairs[pair_idx] + + seq = f"{seqh:08x}{seql:016x}" + seq_path = osp.join(self.ROOT, seq) + + views = [] + + for view_index in [img1, img2]: + impath = f"{view_index:08n}" + image = imread_cv2(osp.join(seq_path, impath + ".jpg")) + depthmap = imread_cv2(osp.join(seq_path, impath + ".exr")) + camera_params = np.load(osp.join(seq_path, impath + ".npz")) + + intrinsics = np.float32(camera_params['intrinsics']) + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = camera_params['R_cam2world'] + camera_pose[:3, 3] = camera_params['t_cam2world'] + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath)) + + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='BlendedMVS', + label=osp.relpath(seq_path, self.ROOT), + instance=impath)) + + return views + + +if __name__ == '__main__': + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = BlendedMVS(split='train', ROOT="data/blendedmvs_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/monst3r/dust3r/datasets/co3d.py b/monst3r/dust3r/datasets/co3d.py new file mode 100644 index 0000000..20a9ac2 --- /dev/null +++ b/monst3r/dust3r/datasets/co3d.py @@ -0,0 +1,192 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed Co3d_v2 +# dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International +# See datasets_preprocess/preprocess_co3d.py +# -------------------------------------------------------- +import sys +sys.path.append('.') +import os.path as osp +import json +import itertools +from collections import deque + +import cv2 +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class Co3d(BaseStereoViewDataset): + def __init__(self, mask_bg=True, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + assert mask_bg in (True, False, 'rand') + self.mask_bg = mask_bg + self.dataset_label = 'Co3d_v2' + + # load all scenes + with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f: + self.scenes = json.load(f) + self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0} + self.scenes = {(k, k2): v2 for k, v in self.scenes.items() + for k2, v2 in v.items()} + self.scene_list = list(self.scenes.keys()) + # 00: + # ('apple', '110_13051_23361') + # 01: + # ('apple', '189_20393_38136') + + # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees) + # we prepare all combinations such that i-j = +/- [5, 10, .., 90] degrees + self.combinations = [(i, j) + for i, j in itertools.combinations(range(100), 2) + if 0 < abs(i - j) <= 30 and abs(i - j) % 5 == 0] + + self.invalidate = {scene: {} for scene in self.scene_list} + + def __len__(self): + return len(self.scene_list) * len(self.combinations) + + def _get_metadatapath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.npz') + + def _get_impath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg') + + def _get_depthpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'depths', f'frame{view_idx:06n}.jpg.geometric.png') + + def _get_maskpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png') + + def _read_depthmap(self, depthpath, input_metadata): + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth']) + return depthmap + + def _get_views(self, idx, resolution, rng): + # choose a scene + obj, instance = self.scene_list[idx // len(self.combinations)] + image_pool = self.scenes[obj, instance] + im1_idx, im2_idx = self.combinations[idx % len(self.combinations)] + + # add a bit of randomness + last = len(image_pool) - 1 + + if resolution not in self.invalidate[obj, instance]: # flag invalid images + self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))] + + # decide now if we mask the bg + mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2)) + + views = [] + imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]] + imgs_idxs = deque(imgs_idxs) + while len(imgs_idxs) > 0: # some images (few) have zero depth + im_idx = imgs_idxs.pop() + + if self.invalidate[obj, instance][resolution][im_idx]: + # search for a valid image + random_direction = 2 * rng.choice(2) - 1 + for offset in range(1, len(image_pool)): + tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool) + if not self.invalidate[obj, instance][resolution][tentative_im_idx]: + im_idx = tentative_im_idx + break + + view_idx = image_pool[im_idx] + + impath = self._get_impath(obj, instance, view_idx) + depthpath = self._get_depthpath(obj, instance, view_idx) + + # load camera params + metadata_path = self._get_metadatapath(obj, instance, view_idx) + input_metadata = np.load(metadata_path) + camera_pose = input_metadata['camera_pose'].astype(np.float32) + intrinsics = input_metadata['camera_intrinsics'].astype(np.float32) + + # load image and depth + rgb_image = imread_cv2(impath) + depthmap = self._read_depthmap(depthpath, input_metadata) + + if mask_bg: + # load object mask + maskpath = self._get_maskpath(obj, instance, view_idx) + maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32) + maskmap = (maskmap / 255.0) > 0.1 + + # update the depthmap with mask + depthmap *= maskmap + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath) + + num_valid = (depthmap > 0.0).sum() + if num_valid == 0: + # problem, invalidate image and retry + self.invalidate[obj, instance][resolution][im_idx] = True + imgs_idxs.append(im_idx) + continue + + views.append(dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset=self.dataset_label, + label=osp.join(obj, instance), + instance=osp.split(impath)[1], + )) + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + import gradio as gr + import random + + def visualize_scene(idx): + views = dataset[idx] + assert len(views) == 2 + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(255, 0, 0), + image=colors, + cam_size=cam_size) + path = f"./tmp/scene_{idx}.glb" + return viz.save_glb(path) + + # Load the dataset + dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16) + + use_gradio = False + if use_gradio: + # Create Gradio interface + iface = gr.Interface( + fn=visualize_scene, + inputs=gr.Slider(0, len(dataset)-1, step=1, label="Index"), + outputs=gr.Model3D(), # Use Model3D output type for 3D models + live=True + ) + + # Launch the interface + iface.launch() + else: + # sample an idx + idx = random.randint(0, len(dataset)-1) + visualize_scene(idx) + print(f"Visualizing scene {idx}...") \ No newline at end of file diff --git a/monst3r/dust3r/datasets/dynamic_replica.py b/monst3r/dust3r/datasets/dynamic_replica.py new file mode 100644 index 0000000..8d6af64 --- /dev/null +++ b/monst3r/dust3r/datasets/dynamic_replica.py @@ -0,0 +1,300 @@ +import sys +sys.path.append('.') +import os +import torch +import numpy as np +import os.path as osp +import torchvision.transforms as transforms +import torch.nn.functional as F +from PIL import Image +from torch._C import dtype, set_flush_denormal +import dust3r.utils.po_utils.basic +import dust3r.utils.po_utils.improc +from dust3r.utils.po_utils.misc import farthest_point_sample_py +from dust3r.utils.po_utils.geom import apply_4x4_py, apply_pix_T_cam_py +import glob +import cv2 +from torchvision.transforms import ColorJitter, GaussianBlur +from functools import partial +import json + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 +from dust3r.utils.misc import get_stride_distribution + +np.random.seed(125) +torch.multiprocessing.set_sharing_strategy('file_system') + + +def convert_ndc_to_pixel_intrinsics( + focal_length_ndc, principal_point_ndc, image_width, image_height, intrinsics_format='ndc_isotropic' +): + f_x_ndc, f_y_ndc = focal_length_ndc + c_x_ndc, c_y_ndc = principal_point_ndc + + # Compute half image size + half_image_size_wh_orig = np.array([image_width, image_height]) / 2.0 + + # Determine rescale factor based on intrinsics_format + if intrinsics_format.lower() == "ndc_norm_image_bounds": + rescale = half_image_size_wh_orig # [image_width/2, image_height/2] + elif intrinsics_format.lower() == "ndc_isotropic": + rescale = np.min(half_image_size_wh_orig) # scalar value + else: + raise ValueError(f"Unknown intrinsics format: {intrinsics_format}") + + # Convert focal length from NDC to pixel coordinates + if intrinsics_format.lower() == "ndc_norm_image_bounds": + focal_length_px = np.array([f_x_ndc, f_y_ndc]) * rescale + elif intrinsics_format.lower() == "ndc_isotropic": + focal_length_px = np.array([f_x_ndc, f_y_ndc]) * rescale + + # Convert principal point from NDC to pixel coordinates + principal_point_px = half_image_size_wh_orig - np.array([c_x_ndc, c_y_ndc]) * rescale + + # Construct the intrinsics matrix in pixel coordinates + K_pixel = np.array([ + [focal_length_px[0], 0, principal_point_px[0]], + [0, focal_length_px[1], principal_point_px[1]], + [0, 0, 1] + ]) + + return K_pixel + +def load_16big_png_depth(depth_png): + with Image.open(depth_png) as depth_pil: + # the image is stored with 16-bit depth but PIL reads it as I (32 bit). + # we cast it to uint16, then reinterpret as float16, then cast to float32 + depth = ( + np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) + .astype(np.float32) + .reshape((depth_pil.size[1], depth_pil.size[0])) + ) + return depth + +class DynamicReplicaDUSt3R(BaseStereoViewDataset): + def __init__(self, + dataset_location='data/dynamic_replica', + use_augs=False, + S=2, + strides=[1,2,3,4,5,6,7,8,9], + clip_step=2, + quick=False, + verbose=False, + dist_type=None, + clip_step_last_skip = 0, + *args, + **kwargs + ): + + print('loading pointodyssey dataset...') + super().__init__(*args, **kwargs) + self.dataset_label = 'pointodyssey' + self.S = S # stride + self.verbose = verbose + + self.use_augs = use_augs + + self.rgb_paths = [] + self.depth_paths = [] + self.normal_paths = [] + self.traj_paths = [] + self.annotation_paths = [] + self.full_idxs = [] + self.sample_stride = [] + self.strides = strides + + self.subdirs = [] + self.sequences = [] + self.subdirs.append(os.path.join(dataset_location)) + + anno_path = os.path.join(dataset_location, 'frame_annotations_train.json') + with open(anno_path, 'r') as f: + self.anno = json.load(f) + + #organize anno by 'sequence_name' + anno_by_seq = {} + for a in self.anno: + seq_name = a['sequence_name'] + if seq_name not in anno_by_seq: + anno_by_seq[seq_name] = [] + anno_by_seq[seq_name].append(a) + + for subdir in self.subdirs: + for seq in glob.glob(os.path.join(subdir, "*/")): + seq_name = seq.split('/')[-1] + self.sequences.append(seq) + + self.sequences = anno_by_seq.keys() + if self.verbose: + print(self.sequences) + print('found %d unique videos in %s ' % (len(self.sequences), dataset_location)) + + + if quick: + self.sequences = self.sequences[1:2] + + for seq in self.sequences: + if self.verbose: + print('seq', seq) + + anno = anno_by_seq[seq] + + + for stride in strides: + for ii in range(0,len(anno)-self.S*max(stride,clip_step_last_skip)+1, clip_step): + full_idx = ii + np.arange(self.S)*stride + self.rgb_paths.append([os.path.join(dataset_location, anno[idx]['image']['path']) for idx in full_idx]) + self.depth_paths.append([os.path.join(dataset_location, anno[idx]['depth']['path']) for idx in full_idx]) + # check if all paths are valid, if not, skip + if not all([os.path.exists(p) for p in self.rgb_paths[-1]]) or not all([os.path.exists(p) for p in self.depth_paths[-1]]): + self.rgb_paths.pop() + self.depth_paths.pop() + continue + self.annotation_paths.append([anno[idx]['viewpoint'] for idx in full_idx]) + self.full_idxs.append(full_idx) + self.sample_stride.append(stride) + if self.verbose: + sys.stdout.write('.') + sys.stdout.flush() + + + self.stride_counts = {} + self.stride_idxs = {} + for stride in strides: + self.stride_counts[stride] = 0 + self.stride_idxs[stride] = [] + for i, stride in enumerate(self.sample_stride): + self.stride_counts[stride] += 1 + self.stride_idxs[stride].append(i) + print('stride counts:', self.stride_counts) + + if len(strides) > 1 and dist_type is not None: + self._resample_clips(strides, dist_type) + + print('collected %d clips of length %d in %s' % ( + len(self.rgb_paths), self.S, dataset_location,)) + + def _resample_clips(self, strides, dist_type): + + # Get distribution of strides, and sample based on that + dist = get_stride_distribution(strides, dist_type=dist_type) + dist = dist / np.max(dist) + max_num_clips = self.stride_counts[strides[np.argmax(dist)]] + num_clips_each_stride = [min(self.stride_counts[stride], int(dist[i]*max_num_clips)) for i, stride in enumerate(strides)] + print('resampled_num_clips_each_stride:', num_clips_each_stride) + resampled_idxs = [] + for i, stride in enumerate(strides): + resampled_idxs += np.random.choice(self.stride_idxs[stride], num_clips_each_stride[i], replace=False).tolist() + + self.rgb_paths = [self.rgb_paths[i] for i in resampled_idxs] + self.depth_paths = [self.depth_paths[i] for i in resampled_idxs] + self.annotation_paths = [self.annotation_paths[i] for i in resampled_idxs] + self.full_idxs = [self.full_idxs[i] for i in resampled_idxs] + self.sample_stride = [self.sample_stride[i] for i in resampled_idxs] + + def __len__(self): + return len(self.rgb_paths) + + def _get_views(self, index, resolution, rng): + + rgb_paths = self.rgb_paths[index] + depth_paths = self.depth_paths[index] + full_idx = self.full_idxs[index] + annotations = self.annotation_paths[index] + focals = [np.array(annotations[i]['focal_length']).astype(np.float32) for i in range(2)] + pp = [np.array(annotations[i]['principal_point']).astype(np.float32) for i in range(2)] + intrinsics_format = [annotations[i]['intrinsics_format'] for i in range(2)] + cams_T_world_R = [np.array(annotations[i]['R']).astype(np.float32) for i in range(2)] + cams_T_world_t = [np.array(annotations[i]['T']).astype(np.float32) for i in range(2)] + + views = [] + for i in range(2): + + impath = rgb_paths[i] + depthpath = depth_paths[i] + + # load camera params + R = cams_T_world_R[i] + t = cams_T_world_t[i] + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3,:3] = R.T + camera_pose[:3,3] = -R.T @ t + + # load image and depth + rgb_image = imread_cv2(impath) + depthmap = load_16big_png_depth(depthpath) + + # load intrinsics + intrinsics = convert_ndc_to_pixel_intrinsics(focals[i], pp[i], rgb_image.shape[1], rgb_image.shape[0], + intrinsics_format=intrinsics_format[i]) + intrinsics = intrinsics.astype(np.float32) + + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath) + + views.append(dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset=self.dataset_label, + label=rgb_paths[i].split('/')[-3], + instance=osp.split(rgb_paths[i])[1], + )) + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + import gradio as gr + import random + + use_augs = False + S = 2 + strides = [1,2,3,4,5,6,7,8,9] + clip_step = 2 + quick = False # Set to True for quick testing + + def visualize_scene(idx): + views = dataset[idx] + assert len(views) == 2 + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.25) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(255, 0, 0), + image=colors, + cam_size=cam_size) + os.makedirs('./tmp/replica', exist_ok=True) + path = f"./tmp/replica/replica_scene_{idx}.glb" + return viz.save_glb(path) + + dataset = DynamicReplicaDUSt3R( + use_augs=use_augs, + S=S, + strides=strides, + clip_step=clip_step, + quick=quick, + verbose=False, + resolution=512, + aug_crop=16, + dist_type='linear_1_9', + aug_focal=1.0, + z_far=80) + + idxs = np.arange(0, len(dataset)-1, (len(dataset)-1)//10) + # idx = random.randint(0, len(dataset)-1) + # idx = 0 + for idx in idxs: + print(f"Visualizing scene {idx}...") + visualize_scene(idx) diff --git a/monst3r/dust3r/datasets/habitat.py b/monst3r/dust3r/datasets/habitat.py new file mode 100644 index 0000000..11ce8a0 --- /dev/null +++ b/monst3r/dust3r/datasets/habitat.py @@ -0,0 +1,107 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed habitat +# dataset at https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md +# See datasets_preprocess/habitat for more details +# -------------------------------------------------------- +import os.path as osp +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" # noqa +import cv2 # noqa +import numpy as np +from PIL import Image +import json + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset + + +class Habitat(BaseStereoViewDataset): + def __init__(self, size, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + assert self.split is not None + # loading list of scenes + with open(osp.join(self.ROOT, f'Habitat_{size}_scenes_{self.split}.txt')) as f: + self.scenes = f.read().splitlines() + self.instances = list(range(1, 5)) + + def filter_scene(self, label, instance=None): + if instance: + subscene, instance = instance.split('_') + label += '/' + subscene + self.instances = [int(instance) - 1] + valid = np.bool_([scene.startswith(label) for scene in self.scenes]) + assert sum(valid), 'no scene was selected for {label=} {instance=}' + self.scenes = [scene for i, scene in enumerate(self.scenes) if valid[i]] + + def _get_views(self, idx, resolution, rng): + scene = self.scenes[idx] + data_path, key = osp.split(osp.join(self.ROOT, scene)) + views = [] + two_random_views = [0, rng.choice(self.instances)] # view 0 is connected with all other views + for view_index in two_random_views: + # load the view (and use the next one if this one's broken) + for ii in range(view_index, view_index + 5): + image, depthmap, intrinsics, camera_pose = self._load_one_view(data_path, key, ii % 5, resolution, rng) + if np.isfinite(camera_pose).all(): + break + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='Habitat', + label=osp.relpath(data_path, self.ROOT), + instance=f"{key}_{view_index}")) + return views + + def _load_one_view(self, data_path, key, view_index, resolution, rng): + view_index += 1 # file indices starts at 1 + impath = osp.join(data_path, f"{key}_{view_index}.jpeg") + image = Image.open(impath) + + depthmap_filename = osp.join(data_path, f"{key}_{view_index}_depth.exr") + depthmap = cv2.imread(depthmap_filename, cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH) + + camera_params_filename = osp.join(data_path, f"{key}_{view_index}_camera_params.json") + with open(camera_params_filename, 'r') as f: + camera_params = json.load(f) + + intrinsics = np.float32(camera_params['camera_intrinsics']) + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = camera_params['R_cam2world'] + camera_pose[:3, 3] = camera_params['t_cam2world'] + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=impath) + return image, depthmap, intrinsics, camera_pose + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = Habitat(1_000_000, split='train', ROOT="data/habitat_processed", + resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/monst3r/dust3r/datasets/megadepth.py b/monst3r/dust3r/datasets/megadepth.py new file mode 100644 index 0000000..8131498 --- /dev/null +++ b/monst3r/dust3r/datasets/megadepth.py @@ -0,0 +1,123 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed MegaDepth +# dataset at https://www.cs.cornell.edu/projects/megadepth/ +# See datasets_preprocess/preprocess_megadepth.py +# -------------------------------------------------------- +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class MegaDepth(BaseStereoViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data(self.split) + + if self.split is None: + pass + elif self.split == 'train': + self.select_scene(('0015', '0022'), opposite=True) + elif self.split == 'val': + self.select_scene(('0015', '0022')) + else: + raise ValueError(f'bad {self.split=}') + + def _load_data(self, split): + with np.load(osp.join(self.ROOT, 'all_metadata.npz')) as data: + self.all_scenes = data['scenes'] + self.all_images = data['images'] + self.pairs = data['pairs'] + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs from {len(self.all_scenes)} scenes' + + def select_scene(self, scene, *instances, opposite=False): + scenes = (scene,) if isinstance(scene, str) else tuple(scene) + scene_id = [s.startswith(scenes) for s in self.all_scenes] + assert any(scene_id), 'no scene found' + + valid = np.in1d(self.pairs['scene_id'], np.nonzero(scene_id)[0]) + if instances: + image_id = [i.startswith(instances) for i in self.all_images] + image_id = np.nonzero(image_id)[0] + assert len(image_id), 'no instance found' + # both together? + if len(instances) == 2: + valid &= np.in1d(self.pairs['im1_id'], image_id) & np.in1d(self.pairs['im2_id'], image_id) + else: + valid &= np.in1d(self.pairs['im1_id'], image_id) | np.in1d(self.pairs['im2_id'], image_id) + + if opposite: + valid = ~valid + assert valid.any() + self.pairs = self.pairs[valid] + + def _get_views(self, pair_idx, resolution, rng): + scene_id, im1_id, im2_id, score = self.pairs[pair_idx] + + scene, subscene = self.all_scenes[scene_id].split() + seq_path = osp.join(self.ROOT, scene, subscene) + + views = [] + + for im_id in [im1_id, im2_id]: + img = self.all_images[im_id] + try: + image = imread_cv2(osp.join(seq_path, img + '.jpg')) + depthmap = imread_cv2(osp.join(seq_path, img + ".exr")) + camera_params = np.load(osp.join(seq_path, img + ".npz")) + except Exception as e: + raise OSError(f'cannot load {img}, got exception {e}') + + intrinsics = np.float32(camera_params['intrinsics']) + camera_pose = np.float32(camera_params['cam2world']) + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(seq_path, img)) + + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='MegaDepth', + label=osp.relpath(seq_path, self.ROOT), + instance=img)) + + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = MegaDepth(split='train', ROOT="data/megadepth_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/monst3r/dust3r/datasets/pointodyssey.py b/monst3r/dust3r/datasets/pointodyssey.py new file mode 100644 index 0000000..8d12b73 --- /dev/null +++ b/monst3r/dust3r/datasets/pointodyssey.py @@ -0,0 +1,259 @@ +import sys +sys.path.append('.') +import os +import torch +import numpy as np +import os.path as osp +import torchvision.transforms as transforms +import torch.nn.functional as F +from PIL import Image +from torch._C import dtype, set_flush_denormal +import dust3r.utils.po_utils.basic +import dust3r.utils.po_utils.improc +from dust3r.utils.po_utils.misc import farthest_point_sample_py +from dust3r.utils.po_utils.geom import apply_4x4_py, apply_pix_T_cam_py +import glob +import cv2 +from torchvision.transforms import ColorJitter, GaussianBlur +from functools import partial + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 +from dust3r.utils.misc import get_stride_distribution + +np.random.seed(125) +torch.multiprocessing.set_sharing_strategy('file_system') + +class PointOdysseyDUSt3R(BaseStereoViewDataset): + def __init__(self, + dataset_location='data/pointodyssey', + dset='train', + use_augs=False, + S=2, + N=16, + strides=[1,2,3,4,5,6,7,8,9], + clip_step=2, + quick=False, + verbose=False, + dist_type=None, + clip_step_last_skip = 0, + *args, + **kwargs + ): + + print('loading pointodyssey dataset...') + super().__init__(*args, **kwargs) + self.dataset_label = 'pointodyssey' + self.split = dset + self.S = S # stride + self.N = N # min num points + self.verbose = verbose + + self.use_augs = use_augs + self.dset = dset + + self.rgb_paths = [] + self.depth_paths = [] + self.normal_paths = [] + self.traj_paths = [] + self.annotation_paths = [] + self.full_idxs = [] + self.sample_stride = [] + self.strides = strides + + self.subdirs = [] + self.sequences = [] + self.subdirs.append(os.path.join(dataset_location, dset)) + + for subdir in self.subdirs: + for seq in glob.glob(os.path.join(subdir, "*/")): + seq_name = seq.split('/')[-1] + self.sequences.append(seq) + + self.sequences = sorted(self.sequences) + if self.verbose: + print(self.sequences) + print('found %d unique videos in %s (dset=%s)' % (len(self.sequences), dataset_location, dset)) + + ## load trajectories + print('loading trajectories...') + + if quick: + self.sequences = self.sequences[1:2] + + for seq in self.sequences: + if self.verbose: + print('seq', seq) + + rgb_path = os.path.join(seq, 'rgbs') + info_path = os.path.join(seq, 'info.npz') + annotations_path = os.path.join(seq, 'anno.npz') + + if os.path.isfile(info_path) and os.path.isfile(annotations_path): + + info = np.load(info_path, allow_pickle=True) + trajs_3d_shape = info['trajs_3d'].astype(np.float32) + + if len(trajs_3d_shape) and trajs_3d_shape[1] > self.N: + + for stride in strides: + for ii in range(0,len(os.listdir(rgb_path))-self.S*max(stride,clip_step_last_skip)+1, clip_step): + full_idx = ii + np.arange(self.S)*stride + self.rgb_paths.append([os.path.join(seq, 'rgbs', 'rgb_%05d.jpg' % idx) for idx in full_idx]) + self.depth_paths.append([os.path.join(seq, 'depths', 'depth_%05d.png' % idx) for idx in full_idx]) + self.normal_paths.append([os.path.join(seq, 'normals', 'normal_%05d.jpg' % idx) for idx in full_idx]) + self.annotation_paths.append(os.path.join(seq, 'anno.npz')) + self.full_idxs.append(full_idx) + self.sample_stride.append(stride) + if self.verbose: + sys.stdout.write('.') + sys.stdout.flush() + elif self.verbose: + print('rejecting seq for missing 3d') + elif self.verbose: + print('rejecting seq for missing info or anno') + + self.stride_counts = {} + self.stride_idxs = {} + for stride in strides: + self.stride_counts[stride] = 0 + self.stride_idxs[stride] = [] + for i, stride in enumerate(self.sample_stride): + self.stride_counts[stride] += 1 + self.stride_idxs[stride].append(i) + print('stride counts:', self.stride_counts) + + if len(strides) > 1 and dist_type is not None: + self._resample_clips(strides, dist_type) + + print('collected %d clips of length %d in %s (dset=%s)' % ( + len(self.rgb_paths), self.S, dataset_location, dset)) + + def _resample_clips(self, strides, dist_type): + + # Get distribution of strides, and sample based on that + dist = get_stride_distribution(strides, dist_type=dist_type) + dist = dist / np.max(dist) + max_num_clips = self.stride_counts[strides[np.argmax(dist)]] + num_clips_each_stride = [min(self.stride_counts[stride], int(dist[i]*max_num_clips)) for i, stride in enumerate(strides)] + print('resampled_num_clips_each_stride:', num_clips_each_stride) + resampled_idxs = [] + for i, stride in enumerate(strides): + resampled_idxs += np.random.choice(self.stride_idxs[stride], num_clips_each_stride[i], replace=False).tolist() + + self.rgb_paths = [self.rgb_paths[i] for i in resampled_idxs] + self.depth_paths = [self.depth_paths[i] for i in resampled_idxs] + self.normal_paths = [self.normal_paths[i] for i in resampled_idxs] + self.annotation_paths = [self.annotation_paths[i] for i in resampled_idxs] + self.full_idxs = [self.full_idxs[i] for i in resampled_idxs] + self.sample_stride = [self.sample_stride[i] for i in resampled_idxs] + + def __len__(self): + return len(self.rgb_paths) + + def _get_views(self, index, resolution, rng): + + rgb_paths = self.rgb_paths[index] + depth_paths = self.depth_paths[index] + normal_paths = self.normal_paths[index] + full_idx = self.full_idxs[index] + annotations_path = self.annotation_paths[index] + annotations = np.load(annotations_path, allow_pickle=True) + pix_T_cams = annotations['intrinsics'][full_idx].astype(np.float32) + world_T_cams = annotations['extrinsics'][full_idx].astype(np.float32) + + views = [] + for i in range(2): + + impath = rgb_paths[i] + depthpath = depth_paths[i] + normalpath = normal_paths[i] + + # load camera params + extrinsics = world_T_cams[i] + R = extrinsics[:3,:3] + t = extrinsics[:3,3] + camera_pose = np.eye(4, dtype=np.float32) # cam_2_world + camera_pose[:3,:3] = R.T + camera_pose[:3,3] = -R.T @ t + intrinsics = pix_T_cams[i] + + # load image and depth + rgb_image = imread_cv2(impath) + depth16 = cv2.imread(depthpath, cv2.IMREAD_ANYDEPTH) + depthmap = depth16.astype(np.float32) / 65535.0 * 1000.0 # 1000 is the max depth in the dataset + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath) + + views.append(dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset=self.dataset_label, + label=rgb_paths[i].split('/')[-3], + instance=osp.split(rgb_paths[i])[1], + )) + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + import gradio as gr + import random + + dataset_location = 'data/point_odyssey' # Change this to the correct path + dset = 'train' + use_augs = False + S = 2 + N = 1 + strides = [1,2,3,4,5,6,7,8,9] + clip_step = 2 + quick = False # Set to True for quick testing + + def visualize_scene(idx): + views = dataset[idx] + assert len(views) == 2 + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.25) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(255, 0, 0), + image=colors, + cam_size=cam_size) + os.makedirs('./tmp/po', exist_ok=True) + path = f"./tmp/po/po_scene_{idx}.glb" + return viz.save_glb(path) + + dataset = PointOdysseyDUSt3R( + dataset_location=dataset_location, + dset=dset, + use_augs=use_augs, + S=S, + N=N, + strides=strides, + clip_step=clip_step, + quick=quick, + verbose=False, + resolution=224, + aug_crop=16, + dist_type='linear_9_1', + aug_focal=1.5, + z_far=80) +# around 514k samples + + idxs = np.arange(0, len(dataset)-1, (len(dataset)-1)//10) + # idx = random.randint(0, len(dataset)-1) + # idx = 0 + for idx in idxs: + print(f"Visualizing scene {idx}...") + visualize_scene(idx) diff --git a/monst3r/dust3r/datasets/scannetpp.py b/monst3r/dust3r/datasets/scannetpp.py new file mode 100644 index 0000000..520deed --- /dev/null +++ b/monst3r/dust3r/datasets/scannetpp.py @@ -0,0 +1,96 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed scannet++ +# dataset at https://github.com/scannetpp/scannetpp - non-commercial research and educational purposes +# https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf +# See datasets_preprocess/preprocess_scannetpp.py +# -------------------------------------------------------- +import os.path as osp +import cv2 +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class ScanNetpp(BaseStereoViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + assert self.split == 'train' + self.loaded_data = self._load_data() + + def _load_data(self): + with np.load(osp.join(self.ROOT, 'all_metadata.npz')) as data: + self.scenes = data['scenes'] + self.sceneids = data['sceneids'] + self.images = data['images'] + self.intrinsics = data['intrinsics'].astype(np.float32) + self.trajectories = data['trajectories'].astype(np.float32) + self.pairs = data['pairs'][:, :2].astype(int) + + def __len__(self): + return len(self.pairs) + + def _get_views(self, idx, resolution, rng): + + image_idx1, image_idx2 = self.pairs[idx] + + views = [] + for view_idx in [image_idx1, image_idx2]: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(scene_dir, 'images', basename + '.jpg')) + # Load depthmap + depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename + '.png'), cv2.IMREAD_UNCHANGED) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) + + views.append(dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset='ScanNet++', + label=self.scenes[scene_id] + '_' + basename, + instance=f'{str(idx)}_{str(view_idx)}', + )) + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = ScanNetpp(split='train', ROOT="data/scannetpp_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx*255, (1 - idx)*255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/monst3r/dust3r/datasets/sintel.py b/monst3r/dust3r/datasets/sintel.py new file mode 100644 index 0000000..41929fc --- /dev/null +++ b/monst3r/dust3r/datasets/sintel.py @@ -0,0 +1,274 @@ +import sys +sys.path.append('.') +import os +import torch +import numpy as np +import os.path as osp +import glob +import PIL.Image +import torchvision.transforms as tvf + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2, crop_img +from dust3r.utils.misc import get_stride_distribution + +np.random.seed(125) +torch.multiprocessing.set_sharing_strategy('file_system') +TAG_FLOAT = 202021.25 +ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) +ToTensor = tvf.ToTensor() + +def depth_read(filename): + """ Read depth data from file, return as numpy array. """ + f = open(filename,'rb') + check = np.fromfile(f,dtype=np.float32,count=1)[0] + assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) + width = np.fromfile(f,dtype=np.int32,count=1)[0] + height = np.fromfile(f,dtype=np.int32,count=1)[0] + size = width*height + assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) + depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width)) + return depth + +def cam_read(filename): + """ Read camera data, return (M,N) tuple. + + M is the intrinsic matrix, N is the extrinsic matrix, so that + + x = M*N*X, + with x being a point in homogeneous image pixel coordinates, X being a + point in homogeneous world coordinates. + """ + f = open(filename,'rb') + check = np.fromfile(f,dtype=np.float32,count=1)[0] + assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) + M = np.fromfile(f,dtype='float64',count=9).reshape((3,3)) + N = np.fromfile(f,dtype='float64',count=12).reshape((3,4)) + return M,N + +class SintelDUSt3R(BaseStereoViewDataset): + def __init__(self, + dataset_location='data/sintel/training', + dset='clean', + use_augs=False, + S=2, + strides=[7], + clip_step=2, + quick=False, + verbose=False, + dist_type=None, + clip_step_last_skip = 0, + load_dynamic_mask=True, + *args, + **kwargs + ): + + print('loading sintel dataset...') + super().__init__(*args, **kwargs) + self.dataset_label = 'sintel' + self.split = dset + self.S = S # stride + self.verbose = verbose + self.load_dynamic_mask = load_dynamic_mask + + self.use_augs = use_augs + self.dset = dset + + self.rgb_paths = [] + self.depth_paths = [] + self.traj_paths = [] + self.annotation_paths = [] + self.dynamic_mask_paths = [] + self.full_idxs = [] + self.sample_stride = [] + self.strides = strides + + self.subdirs = [] + self.sequences = [] + self.subdirs.append(os.path.join(dataset_location, dset)) + + for subdir in self.subdirs: + for seq in glob.glob(os.path.join(subdir, "*/")): + self.sequences.append(seq) + + self.sequences = sorted(self.sequences) + if self.verbose: + print(self.sequences) + print('found %d unique videos in %s (dset=%s)' % (len(self.sequences), dataset_location, dset)) + + ## load trajectories + print('loading trajectories...') + + if quick: + self.sequences = self.sequences[1:2] + + for seq in self.sequences: + if self.verbose: + print('seq', seq) + + rgb_path = seq + depth_path = seq.replace(dset,'depth') + caminfo_path = seq.replace(dset,'camdata_left') + dynamic_mask_path = seq.replace(dset,'dynamic_label_perfect') + + for stride in strides: + for ii in range(1,len(os.listdir(rgb_path))-self.S*max(stride,clip_step_last_skip)+1, clip_step): + full_idx = ii + np.arange(self.S)*stride + self.rgb_paths.append([os.path.join(rgb_path, 'frame_%04d.png' % idx) for idx in full_idx]) + self.depth_paths.append([os.path.join(depth_path, 'frame_%04d.dpt' % idx) for idx in full_idx]) + self.annotation_paths.append([os.path.join(caminfo_path, 'frame_%04d.cam' % idx) for idx in full_idx]) + self.dynamic_mask_paths.append([os.path.join(dynamic_mask_path, 'frame_%04d.png' % idx) for idx in full_idx]) + self.full_idxs.append(full_idx) + self.sample_stride.append(stride) + if self.verbose: + sys.stdout.write('.') + sys.stdout.flush() + + self.stride_counts = {} + self.stride_idxs = {} + for stride in strides: + self.stride_counts[stride] = 0 + self.stride_idxs[stride] = [] + for i, stride in enumerate(self.sample_stride): + self.stride_counts[stride] += 1 + self.stride_idxs[stride].append(i) + print('stride counts:', self.stride_counts) + + if len(strides) > 1 and dist_type is not None: + self._resample_clips(strides, dist_type) + + print('collected %d clips of length %d in %s (dset=%s)' % ( + len(self.rgb_paths), self.S, dataset_location, dset)) + + def _resample_clips(self, strides, dist_type): + + # Get distribution of strides, and sample based on that + dist = get_stride_distribution(strides, dist_type=dist_type) + dist = dist / np.max(dist) + max_num_clips = self.stride_counts[strides[np.argmax(dist)]] + num_clips_each_stride = [min(self.stride_counts[stride], int(dist[i]*max_num_clips)) for i, stride in enumerate(strides)] + print('resampled_num_clips_each_stride:', num_clips_each_stride) + resampled_idxs = [] + for i, stride in enumerate(strides): + resampled_idxs += np.random.choice(self.stride_idxs[stride], num_clips_each_stride[i], replace=False).tolist() + + self.rgb_paths = [self.rgb_paths[i] for i in resampled_idxs] + self.depth_paths = [self.depth_paths[i] for i in resampled_idxs] + self.annotation_paths = [self.annotation_paths[i] for i in resampled_idxs] + self.dynamic_mask_paths = [self.dynamic_mask_paths[i] for i in resampled_idxs] + self.full_idxs = [self.full_idxs[i] for i in resampled_idxs] + self.sample_stride = [self.sample_stride[i] for i in resampled_idxs] + + def __len__(self): + return len(self.rgb_paths) + + def _get_views(self, index, resolution, rng): + + rgb_paths = self.rgb_paths[index] + depth_paths = self.depth_paths[index] + full_idx = self.full_idxs[index] + annotations_paths = self.annotation_paths[index] + dynamic_mask_paths = self.dynamic_mask_paths[index] + + views = [] + for i in range(2): + impath = rgb_paths[i] + depthpath = depth_paths[i] + dynamic_mask_path = dynamic_mask_paths[i] + + # load camera params + intrinsics, extrinsics = cam_read(annotations_paths[i]) + intrinsics, extrinsics = np.array(intrinsics, dtype=np.float32), np.array(extrinsics, dtype=np.float32) + R = extrinsics[:3,:3] + t = extrinsics[:3,3] + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3,:3] = R.T + camera_pose[:3,3] = -R.T @ t + + # load image and depth + rgb_image = imread_cv2(impath) + depthmap = depth_read(depthpath) + + # load dynamic mask + if dynamic_mask_path is not None and os.path.exists(dynamic_mask_path): + dynamic_mask = PIL.Image.open(dynamic_mask_path).convert('L') + dynamic_mask = ToTensor(dynamic_mask).sum(0).numpy() + _, dynamic_mask, _ = self._crop_resize_if_necessary( + rgb_image, dynamic_mask, intrinsics, resolution, rng=rng, info=impath) + dynamic_mask = dynamic_mask > 0.5 + assert not np.all(dynamic_mask), f"Dynamic mask is all True for {impath}" + else: + dynamic_mask = np.ones((resolution[1],resolution[0]), dtype=bool) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath) + + if self.load_dynamic_mask: + views.append(dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset=self.dataset_label, + label=rgb_paths[i].split('/')[-2], + instance=osp.split(rgb_paths[i])[1], + dynamic_mask=dynamic_mask, + )) + else: + views.append(dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset=self.dataset_label, + label=rgb_paths[i].split('/')[-2], + instance=osp.split(rgb_paths[i])[1], + )) + return views + + +if __name__ == "__main__": + + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + use_augs = False + S = 2 + strides = [1] + clip_step = 1 + quick = False # Set to True for quick testing + + + def visualize_scene(idx): + views = dataset[idx] + assert len(views) == 2 + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.25) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(255, 0, 0), + image=colors, + cam_size=cam_size) + path = f"./tmp/sintel_scene_{idx}.glb" + return viz.save_glb(path) + + dataset = SintelDUSt3R( + use_augs=use_augs, + S=S, + strides=strides, + clip_step=clip_step, + quick=quick, + verbose=False, + resolution=(512,224), + seed = 777, + clip_step_last_skip=0, + aug_crop=16) + + idx = random.randint(0, len(dataset)-1) + visualize_scene(idx) \ No newline at end of file diff --git a/monst3r/dust3r/datasets/spring_dataset.py b/monst3r/dust3r/datasets/spring_dataset.py new file mode 100644 index 0000000..a76c9af --- /dev/null +++ b/monst3r/dust3r/datasets/spring_dataset.py @@ -0,0 +1,242 @@ +import sys +sys.path.append('.') +import os +import torch +import numpy as np +import os.path as osp +import glob + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 +from dust3r.utils.misc import get_stride_distribution +import h5py + +SPRING_BASELINE = 0.065 +np.random.seed(125) +torch.multiprocessing.set_sharing_strategy('file_system') + +def get_depth(disp1, intrinsics, baseline=SPRING_BASELINE): + """ + get depth from reference frame disparity and camera intrinsics + """ + return intrinsics[0] * baseline / disp1 + + +def readDsp5Disp(filename): + with h5py.File(filename, "r") as f: + if "disparity" not in f.keys(): + raise IOError(f"File {filename} does not have a 'disparity' key. Is this a valid dsp5 file?") + return f["disparity"][()] + + +class SpringDUSt3R(BaseStereoViewDataset): + def __init__(self, + dataset_location='data/spring', + dset='train', + use_augs=False, + S=2, + strides=[8], + clip_step=2, + quick=False, + verbose=False, + dist_type=None, + remove_seq_list=[], + *args, + **kwargs + ): + + print('loading Spring dataset...') + super().__init__(*args, **kwargs) + self.dataset_label = 'Spring' + self.split = dset + self.S = S # number of frames + self.verbose = verbose + + self.use_augs = use_augs + self.dset = dset + + self.rgb_paths = [] + self.depth_paths = [] + self.annotations = [] + self.intrinsics = [] + self.full_idxs = [] + self.sample_stride = [] + self.strides = strides + + self.subdirs = [] + self.sequences = [] + self.subdirs.append(os.path.join(dataset_location, dset)) + + for subdir in self.subdirs: + for seq in glob.glob(os.path.join(subdir, "*/")): + seq_name = seq.split('/')[-2] + print(f"Processing sequence {seq_name}") + # remove_seq_list = ['0008', '0041', '0043'] + if seq_name in remove_seq_list: + print(f"Skipping sequence {seq_name}") + continue + self.sequences.append(seq) + + self.sequences = sorted(self.sequences) + if self.verbose: + print(self.sequences) + print('found %d unique videos in %s (dset=%s)' % (len(self.sequences), dataset_location, dset)) + + if quick: + self.sequences = self.sequences[1:2] + + for seq in self.sequences: + if self.verbose: + print('seq', seq) + + rgb_path = os.path.join(seq, 'frame_left') + depth_path = os.path.join(seq, 'disp1_left') + caminfo_path = os.path.join(seq, 'cam_data/extrinsics.txt') + caminfo = np.loadtxt(caminfo_path) + intrinsics_path = os.path.join(seq, 'cam_data/intrinsics.txt') + intrinsics = np.loadtxt(intrinsics_path) + + for stride in strides: + for ii in range(1,len(os.listdir(rgb_path))-self.S*stride+2, clip_step): + full_idx = ii + np.arange(self.S)*stride + self.rgb_paths.append([os.path.join(rgb_path, 'frame_left_%04d.png' % idx) for idx in full_idx]) + self.depth_paths.append([os.path.join(depth_path, 'disp1_left_%04d.dsp5' % idx) for idx in full_idx]) + self.annotations.append(caminfo[full_idx-1]) + self.intrinsics.append(intrinsics[full_idx-1]) + self.full_idxs.append(full_idx) + self.sample_stride.append(stride) + if self.verbose: + sys.stdout.write('.') + sys.stdout.flush() + + self.stride_counts = {} + self.stride_idxs = {} + for stride in strides: + self.stride_counts[stride] = 0 + self.stride_idxs[stride] = [] + for i, stride in enumerate(self.sample_stride): + self.stride_counts[stride] += 1 + self.stride_idxs[stride].append(i) + print('stride counts:', self.stride_counts) + + if len(strides) > 1 and dist_type is not None: + self._resample_clips(strides, dist_type) + + print('collected %d clips of length %d in %s (dset=%s)' % ( + len(self.rgb_paths), self.S, dataset_location, dset)) + + def _resample_clips(self, strides, dist_type): + + # Get distribution of strides, and sample based on that + dist = get_stride_distribution(strides, dist_type=dist_type) + dist = dist / np.max(dist) + max_num_clips = self.stride_counts[strides[np.argmax(dist)]] + num_clips_each_stride = [min(self.stride_counts[stride], int(dist[i]*max_num_clips)) for i, stride in enumerate(strides)] + print('resampled_num_clips_each_stride:', num_clips_each_stride) + resampled_idxs = [] + for i, stride in enumerate(strides): + resampled_idxs += np.random.choice(self.stride_idxs[stride], num_clips_each_stride[i], replace=False).tolist() + + self.rgb_paths = [self.rgb_paths[i] for i in resampled_idxs] + self.depth_paths = [self.depth_paths[i] for i in resampled_idxs] + self.annotations = [self.annotations[i] for i in resampled_idxs] + self.intrinsics = [self.intrinsics[i] for i in resampled_idxs] + self.full_idxs = [self.full_idxs[i] for i in resampled_idxs] + self.sample_stride = [self.sample_stride[i] for i in resampled_idxs] + + def __len__(self): + return len(self.rgb_paths) + + def _get_views(self, index, resolution, rng): + + rgb_paths = self.rgb_paths[index] + depth_paths = self.depth_paths[index] + annotations = self.annotations[index] + intrinsics = self.intrinsics[index] + + views = [] + for i in range(2): + impath = rgb_paths[i] + depthpath = depth_paths[i] + + # load camera params + extrinsic = np.reshape(annotations[i], (4,4)).astype(np.float32) + camera_pose = np.linalg.inv(extrinsic) + intrinsic_matrix = np.zeros((3, 3), dtype=np.float32) + intrinsic_matrix[0, 0] = intrinsics[i][0] + intrinsic_matrix[1, 1] = intrinsics[i][1] + intrinsic_matrix[0, 2] = intrinsics[i][2] + intrinsic_matrix[1, 2] = intrinsics[i][3] + + # load image and depth + rgb_image = imread_cv2(impath) + depthmap = get_depth(readDsp5Disp(depthpath), intrinsics[i]).astype(np.float32) + depthmap = depthmap[::2, ::2] + depthmap = np.where(np.isnan(depthmap), -1, depthmap) + depthmap = np.where(np.isinf(depthmap), -1, depthmap) + + rgb_image, depthmap, intrinsic_matrix = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsic_matrix, resolution, rng=rng, info=impath) + + views.append(dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsic_matrix, + dataset=self.dataset_label, + label=rgb_paths[i].split('/')[-3], + instance=osp.split(rgb_paths[i])[1], + )) + return views + +if __name__ == "__main__": + + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + use_augs = False + S = 2 + strides=[2,4,6,8,10,12,14,16,18] + clip_step = 2 + quick = False # Set to True for quick testing + dist_type='linear_9_1' + + + def visualize_scene(idx): + views = dataset[idx] + assert len(views) == 2 + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 1) + label = views[0]['label'] + instance = views[0]['instance'] + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(255, 0, 0), + image=colors, + cam_size=cam_size) + os.makedirs("./tmp/spring_scene", exist_ok=True) + path = f"./tmp/spring_scene/spring_scene_{label}_{views[0]['instance']}_{views[1]['instance']}.glb" + return viz.save_glb(path) + + dataset = SpringDUSt3R( + use_augs=use_augs, + S=S, + strides=strides, + clip_step=clip_step, + quick=quick, + verbose=False, + resolution=(512,288), + aug_crop=16, + dist_type=dist_type, + z_far=80) + + idxs = np.arange(0, len(dataset)-1, (len(dataset)-1)//10) + for idx in idxs: + print(f"Visualizing scene {idx}...") + visualize_scene(idx) \ No newline at end of file diff --git a/monst3r/dust3r/datasets/staticthings3d.py b/monst3r/dust3r/datasets/staticthings3d.py new file mode 100644 index 0000000..e7f70f0 --- /dev/null +++ b/monst3r/dust3r/datasets/staticthings3d.py @@ -0,0 +1,96 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed StaticThings3D +# dataset at https://github.com/lmb-freiburg/robustmvd/ +# See datasets_preprocess/preprocess_staticthings3d.py +# -------------------------------------------------------- +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class StaticThings3D (BaseStereoViewDataset): + """ Dataset of indoor scenes, 5 images each time + """ + def __init__(self, ROOT, *args, mask_bg='rand', **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + + assert mask_bg in (True, False, 'rand') + self.mask_bg = mask_bg + + # loading all pairs + assert self.split is None + self.pairs = np.load(osp.join(ROOT, 'staticthings_pairs.npy')) + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs' + + def _get_views(self, pair_idx, resolution, rng): + scene, seq, cam1, im1, cam2, im2 = self.pairs[pair_idx] + seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}') + + views = [] + + mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2)) + + CAM = {b'l':'left', b'r':'right'} + for cam, idx in [(CAM[cam1], im1), (CAM[cam2], im2)]: + num = f"{idx:04n}" + img = num+"_clean.jpg" if rng.choice(2) else num+"_final.jpg" + image = imread_cv2(osp.join(self.ROOT, seq_path, cam, img)) + depthmap = imread_cv2(osp.join(self.ROOT, seq_path, cam, num+".exr")) + camera_params = np.load(osp.join(self.ROOT, seq_path, cam, num+".npz")) + + intrinsics = camera_params['intrinsics'] + camera_pose = camera_params['cam2world'] + + if mask_bg: + depthmap[depthmap > 200] = 0 + + image, depthmap, intrinsics = self._crop_resize_if_necessary(image, depthmap, intrinsics, resolution, rng, info=(seq_path,cam,img)) + + views.append(dict( + img = image, + depthmap = depthmap, + camera_pose = camera_pose, # cam2world + camera_intrinsics = intrinsics, + dataset = 'StaticThings3D', + label = seq_path, + instance = cam+'_'+img)) + + return views + + +if __name__ == '__main__': + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = StaticThings3D(ROOT="data/staticthings3d_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx*255, (1 - idx)*255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/monst3r/dust3r/datasets/tartanair.py b/monst3r/dust3r/datasets/tartanair.py new file mode 100644 index 0000000..3bbcb11 --- /dev/null +++ b/monst3r/dust3r/datasets/tartanair.py @@ -0,0 +1,234 @@ + +import sys +sys.path.append('.') +import os +import torch +import numpy as np +import os.path as osp +import glob + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 +from dust3r.utils.misc import get_stride_distribution + +np.random.seed(125) +torch.multiprocessing.set_sharing_strategy('file_system') + +def depth_read(filename): + depth = np.load(filename) + return depth + +def xyzqxqyqxqw_to_c2w(xyzqxqyqxqw): + xyzqxqyqxqw = np.array(xyzqxqyqxqw, dtype=np.float32) + #NOTE: we need to convert x_y_z coordinate system to z_x_y coordinate system + z, x, y = xyzqxqyqxqw[:3] + qz, qx, qy, qw = xyzqxqyqxqw[3:] + c2w = np.eye(4) + c2w[:3, :3] = np.array([ + [1 - 2*qy*qy - 2*qz*qz, 2*qx*qy - 2*qz*qw, 2*qx*qz + 2*qy*qw], + [2*qx*qy + 2*qz*qw, 1 - 2*qx*qx - 2*qz*qz, 2*qy*qz - 2*qx*qw], + [2*qx*qz - 2*qy*qw, 2*qy*qz + 2*qx*qw, 1 - 2*qx*qx - 2*qy*qy] + ]) + c2w[:3, 3] = np.array([x, y, z]) + return c2w + +class TarTanAirDUSt3R(BaseStereoViewDataset): + def __init__(self, + dataset_location='data/tartanair', + dset='Hard', + use_augs=False, + S=2, + strides=[8], + clip_step=2, + quick=False, + verbose=False, + dist_type=None, + *args, + **kwargs + ): + + print('loading tartanair dataset...') + super().__init__(*args, **kwargs) + self.dataset_label = 'tartanair' + self.split = dset + self.S = S # number of frames + self.verbose = verbose + + self.use_augs = use_augs + self.dset = dset + + self.rgb_paths = [] + self.depth_paths = [] + self.normal_paths = [] + self.traj_paths = [] + self.annotations = [] + self.full_idxs = [] + self.sample_stride = [] + self.strides = strides + + self.subdirs = [] + self.sequences = [] + self.subdirs.append(os.path.join(dataset_location)) #'data/tartanair' + + for subdir in self.subdirs: + for seq in glob.glob(os.path.join(subdir, "*/", dset, "*/")): + self.sequences.append(seq) + + self.sequences = sorted(self.sequences) + if self.verbose: + print(self.sequences) + print('found %d unique videos in %s (dset=%s)' % (len(self.sequences), dataset_location, dset)) + + if quick: + self.sequences = self.sequences[1:2] + + for seq in self.sequences: + if self.verbose: + print('seq', seq) + + rgb_path = os.path.join(seq, 'image_left') + depth_path = os.path.join(seq, 'depth_left') + caminfo_path = os.path.join(seq, 'pose_left.txt') + caminfo = np.loadtxt(caminfo_path) + + for stride in strides: + for ii in range(0,len(os.listdir(rgb_path))-self.S*stride+1, clip_step): + full_idx = ii + np.arange(self.S)*stride + self.rgb_paths.append([os.path.join(rgb_path, '%06d_left.png' % idx) for idx in full_idx]) + self.depth_paths.append([os.path.join(depth_path, '%06d_left_depth.npy' % idx) for idx in full_idx]) + self.annotations.append(caminfo[full_idx]) + self.full_idxs.append(full_idx) + self.sample_stride.append(stride) + if self.verbose: + sys.stdout.write('.') + sys.stdout.flush() + + + fx = 320.0 # focal length x + fy = 320.0 # focal length y + cx = 320.0 # optical center x + cy = 240.0 # optical center y + + width = 640 + height = 480 + + self.intrinsics = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + self.stride_counts = {} + self.stride_idxs = {} + for stride in strides: + self.stride_counts[stride] = 0 + self.stride_idxs[stride] = [] + for i, stride in enumerate(self.sample_stride): + self.stride_counts[stride] += 1 + self.stride_idxs[stride].append(i) + print('stride counts:', self.stride_counts) + + if len(strides) > 1 and dist_type is not None: + self._resample_clips(strides, dist_type) + + print('collected %d clips of length %d in %s (dset=%s)' % ( + len(self.rgb_paths), self.S, dataset_location, dset)) + + def _resample_clips(self, strides, dist_type): + + # Get distribution of strides, and sample based on that + dist = get_stride_distribution(strides, dist_type=dist_type) + dist = dist / np.max(dist) + max_num_clips = self.stride_counts[strides[np.argmax(dist)]] + num_clips_each_stride = [min(self.stride_counts[stride], int(dist[i]*max_num_clips)) for i, stride in enumerate(strides)] + print('resampled_num_clips_each_stride:', num_clips_each_stride) + resampled_idxs = [] + for i, stride in enumerate(strides): + resampled_idxs += np.random.choice(self.stride_idxs[stride], num_clips_each_stride[i], replace=False).tolist() + + self.rgb_paths = [self.rgb_paths[i] for i in resampled_idxs] + self.depth_paths = [self.depth_paths[i] for i in resampled_idxs] + self.annotations = [self.annotations[i] for i in resampled_idxs] + self.full_idxs = [self.full_idxs[i] for i in resampled_idxs] + self.sample_stride = [self.sample_stride[i] for i in resampled_idxs] + + def __len__(self): + return len(self.rgb_paths) + + def _get_views(self, index, resolution, rng): + + rgb_paths = self.rgb_paths[index] + depth_paths = self.depth_paths[index] + full_idx = self.full_idxs[index] + annotations = self.annotations[index] + + views = [] + for i in range(2): + impath = rgb_paths[i] + depthpath = depth_paths[i] + + # load camera params + camera_pose = np.array(xyzqxqyqxqw_to_c2w(annotations[i]), dtype=np.float32) + # camera_pose = np.linalg.inv(camera_pose) + + # load image and depth + rgb_image = imread_cv2(impath) + depthmap = depth_read(depthpath) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, self.intrinsics, resolution, rng=rng, info=impath) + + views.append(dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset=self.dataset_label, + label=rgb_paths[i].split('/')[-5]+'-'+rgb_paths[i].split('/')[-3], + instance=osp.split(rgb_paths[i])[1], + )) + return views + +if __name__ == "__main__": + + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + use_augs = False + S = 2 + strides = [1,2,3,4,5,6,7,8,9] + clip_step = 2 + quick = False # Set to True for quick testing + + + def visualize_scene(idx): + views = dataset[idx] + assert len(views) == 2 + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 1) + label = views[0]['label'] + instance = views[0]['instance'] + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(255, 0, 0), + image=colors, + cam_size=cam_size) + path = f"./tmp/tartanair/tartanair_scene_{label}_{instance}.glb" + return viz.save_glb(path) + + dataset = TarTanAirDUSt3R( + use_augs=use_augs, + S=S, + strides=strides, + clip_step=clip_step, + quick=quick, + verbose=False, + resolution=(512,384), + dist_type='linear_9_1', + aug_crop=16) + + idxs = np.arange(0, len(dataset)-1, (len(dataset)-1)//10) + for idx in idxs: + print(f"Visualizing scene {idx}...") + visualize_scene(idx) \ No newline at end of file diff --git a/monst3r/dust3r/datasets/utils/__init__.py b/monst3r/dust3r/datasets/utils/__init__.py new file mode 100644 index 0000000..a326921 --- /dev/null +++ b/monst3r/dust3r/datasets/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/monst3r/dust3r/datasets/utils/cropping.py b/monst3r/dust3r/datasets/utils/cropping.py new file mode 100644 index 0000000..bc9eb89 --- /dev/null +++ b/monst3r/dust3r/datasets/utils/cropping.py @@ -0,0 +1,183 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# croppping utilities +# -------------------------------------------------------- +import PIL.Image +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa +import numpy as np # noqa +from dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + + +class ImageList: + """ Convenience class to aply the same operation to a whole set of images. + """ + + def __init__(self, images): + if not isinstance(images, (tuple, list, set)): + images = [images] + self.images = [] + for image in images: + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + self.images.append(image) + + def __len__(self): + return len(self.images) + + def to_pil(self): + return tuple(self.images) if len(self.images) > 1 else self.images[0] + + @property + def size(self): + sizes = [im.size for im in self.images] + assert all(sizes[0] == s for s in sizes) + return sizes[0] + + def resize(self, *args, **kwargs): + return ImageList(self._dispatch('resize', *args, **kwargs)) + + def crop(self, *args, **kwargs): + return ImageList(self._dispatch('crop', *args, **kwargs)) + + def _dispatch(self, func, *args, **kwargs): + return [getattr(im, func)(*args, **kwargs) for im in self.images] + + +def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution, force=True): + """ Jointly rescale a (image, depthmap) + so that (out_width, out_height) >= output_res + """ + image = ImageList(image) + input_resolution = np.array(image.size) # (W,H) + output_resolution = np.array(output_resolution) + if depthmap is not None: + # can also use this with masks instead of depthmaps + assert tuple(depthmap.shape[:2]) == image.size[::-1] + + # define output resolution + assert output_resolution.shape == (2,) + scale_final = max(output_resolution / image.size) + 1e-8 + if scale_final >= 1 and not force: # image is already smaller than what is asked + return (image.to_pil(), depthmap, camera_intrinsics) + output_resolution = np.floor(input_resolution * scale_final).astype(int) + + # first rescale the image so that it contains the crop + image = image.resize(tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic) + if depthmap is not None: + depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final, + fy=scale_final, interpolation=cv2.INTER_NEAREST) + + # no offset here; simple rescaling + camera_intrinsics = camera_matrix_of_crop( + camera_intrinsics, input_resolution, output_resolution, scaling=scale_final) + + return image.to_pil(), depthmap, camera_intrinsics + +def center_crop_image_depthmap(image, depthmap, camera_intrinsics, crop_scale): + """ + Jointly center-crop an image and its depthmap, and adjust the camera intrinsics accordingly. + + Parameters: + - image: PIL.Image or similar, the input image. + - depthmap: np.ndarray, the corresponding depth map. + - camera_intrinsics: np.ndarray, the 3x3 camera intrinsics matrix. + - crop_scale: float between 0 and 1, the fraction of the image to keep. + + Returns: + - cropped_image: PIL.Image, the center-cropped image. + - cropped_depthmap: np.ndarray, the center-cropped depth map. + - adjusted_intrinsics: np.ndarray, the adjusted camera intrinsics matrix. + """ + # Ensure crop_scale is valid + assert 0 < crop_scale <= 1, "crop_scale must be between 0 and 1" + + # Convert image to ImageList for consistent processing + image = ImageList(image) + input_resolution = np.array(image.size) # (width, height) + if depthmap is not None: + # Ensure depthmap matches the image size + assert depthmap.shape[:2] == tuple(image.size[::-1]), "Depthmap size must match image size" + + # Compute output resolution after cropping + output_resolution = np.floor(input_resolution * crop_scale).astype(int) + # get the correct crop_scale + crop_scale = output_resolution / input_resolution + + # Compute margins (amount to crop from each side) + margins = input_resolution - output_resolution + offset = margins / 2 # Since we are center cropping + + # Calculate the crop bounding box + l, t = offset.astype(int) + r = l + output_resolution[0] + b = t + output_resolution[1] + crop_bbox = (l, t, r, b) + + # Crop the image and depthmap + image = image.crop(crop_bbox) + if depthmap is not None: + depthmap = depthmap[t:b, l:r] + + # Adjust the camera intrinsics + adjusted_intrinsics = camera_intrinsics.copy() + + # Adjust focal lengths (fx, fy) # no need to adjust focal lengths for cropping + # adjusted_intrinsics[0, 0] /= crop_scale[0] # fx + # adjusted_intrinsics[1, 1] /= crop_scale[1] # fy + + # Adjust principal point (cx, cy) + adjusted_intrinsics[0, 2] -= l # cx + adjusted_intrinsics[1, 2] -= t # cy + + return image.to_pil(), depthmap, adjusted_intrinsics + + + +def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None): + # Margins to offset the origin + margins = np.asarray(input_resolution) * scaling - output_resolution + assert np.all(margins >= 0.0) + if offset is None: + offset = offset_factor * margins + + # Generate new camera parameters + output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) + output_camera_matrix_colmap[:2, :] *= scaling + output_camera_matrix_colmap[:2, 2] -= offset + output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) + + return output_camera_matrix + + +def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox): + """ + Return a crop of the input view. + """ + image = ImageList(image) + l, t, r, b = crop_bbox + + image = image.crop((l, t, r, b)) + depthmap = depthmap[t:b, l:r] + + camera_intrinsics = camera_intrinsics.copy() + camera_intrinsics[0, 2] -= l + camera_intrinsics[1, 2] -= t + + return image.to_pil(), depthmap, camera_intrinsics + + +def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution): + out_width, out_height = output_resolution + l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) + crop_bbox = (l, t, l + out_width, t + out_height) + return crop_bbox diff --git a/monst3r/dust3r/datasets/utils/transforms.py b/monst3r/dust3r/datasets/utils/transforms.py new file mode 100644 index 0000000..eb34f2f --- /dev/null +++ b/monst3r/dust3r/datasets/utils/transforms.py @@ -0,0 +1,11 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# DUST3R default transforms +# -------------------------------------------------------- +import torchvision.transforms as tvf +from dust3r.utils.image import ImgNorm + +# define the standard image transforms +ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) diff --git a/monst3r/dust3r/datasets/waymo.py b/monst3r/dust3r/datasets/waymo.py new file mode 100644 index 0000000..8a782ff --- /dev/null +++ b/monst3r/dust3r/datasets/waymo.py @@ -0,0 +1,100 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed WayMo +# dataset at https://github.com/waymo-research/waymo-open-dataset +# See datasets_preprocess/preprocess_waymo.py +# -------------------------------------------------------- +import sys +sys.path.append('.') +import os +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class Waymo (BaseStereoViewDataset): + """ Dataset of outdoor street scenes, 5 images each time + """ + + def __init__(self, *args, ROOT, pairs_npz_name='waymo_pairs_video.npz', **kwargs): + self.ROOT = ROOT + self.pairs_npz_name = pairs_npz_name + super().__init__(*args, **kwargs) + self._load_data() + + def _load_data(self): + with np.load(osp.join(self.ROOT, self.pairs_npz_name)) as data: + self.scenes = data['scenes'] + self.frames = data['frames'] + self.inv_frames = {frame: i for i, frame in enumerate(data['frames'])} + self.pairs = data['pairs'] # (array of (scene_id, img1_id, img2_id) + assert self.pairs[:, 0].max() == len(self.scenes) - 1 + print(f'Loaded {self.get_stats()}') + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs from {len(self.scenes)} scenes' + + def _get_views(self, pair_idx, resolution, rng): + seq, img1, img2 = self.pairs[pair_idx] + seq_path = osp.join(self.ROOT, self.scenes[seq]) + + views = [] + + for view_index in [img1, img2]: + impath = self.frames[view_index] + image = imread_cv2(osp.join(seq_path, impath + ".jpg")) + depthmap = imread_cv2(osp.join(seq_path, impath + ".exr")) + camera_params = np.load(osp.join(seq_path, impath + ".npz")) + + intrinsics = np.float32(camera_params['intrinsics']) + camera_pose = np.float32(camera_params['cam2world']) + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath)) + + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='Waymo', + label=osp.relpath(seq_path, self.ROOT), + instance=impath)) + + return views + + +if __name__ == '__main__': + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = Waymo(split='train', ROOT="data/waymo_processed", resolution=512, aug_crop=16) + idxs = np.arange(0, len(dataset)-1, (len(dataset)-1)//10) + for idx in idxs: + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + os.makedirs('./tmp/waymo', exist_ok=True) + path = f"./tmp/waymo/waymo_scene_{idx}.glb" + viz.save_glb(path) diff --git a/monst3r/dust3r/datasets/wildrgbd.py b/monst3r/dust3r/datasets/wildrgbd.py new file mode 100644 index 0000000..c41dd0b --- /dev/null +++ b/monst3r/dust3r/datasets/wildrgbd.py @@ -0,0 +1,67 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed WildRGB-D +# dataset at https://github.com/wildrgbd/wildrgbd/ +# See datasets_preprocess/preprocess_wildrgbd.py +# -------------------------------------------------------- +import os.path as osp + +import cv2 +import numpy as np + +from dust3r.datasets.co3d import Co3d +from dust3r.utils.image import imread_cv2 + + +class WildRGBD(Co3d): + def __init__(self, mask_bg=True, *args, ROOT, **kwargs): + super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs) + self.dataset_label = 'WildRGBD' + + def _get_metadatapath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'metadata', f'{view_idx:0>5d}.npz') + + def _get_impath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'rgb', f'{view_idx:0>5d}.jpg') + + def _get_depthpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'depth', f'{view_idx:0>5d}.png') + + def _get_maskpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'masks', f'{view_idx:0>5d}.png') + + def _read_depthmap(self, depthpath, input_metadata): + # We store depths in the depth scale of 1000. + # That is, when we load depth image and divide by 1000, we could get depth in meters. + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + depthmap = depthmap.astype(np.float32) / 1000.0 + return depthmap + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = WildRGBD(split='train', ROOT="data/wildrgbd_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/monst3r/dust3r/demo.py b/monst3r/dust3r/demo.py new file mode 100644 index 0000000..70b1bf7 --- /dev/null +++ b/monst3r/dust3r/demo.py @@ -0,0 +1,450 @@ +# -------------------------------------------------------- +# gradio demo +# -------------------------------------------------------- + +import argparse +import math +import gradio +import os +import torch +import numpy as np +import tempfile +import functools +import copy + +from dust3r.inference import inference +from dust3r.model import AsymmetricCroCo3DStereo +from dust3r.image_pairs import make_pairs +from dust3r.utils.image import load_images, rgb, enlarge_seg_masks +from dust3r.utils.device import to_numpy +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode +from dust3r.utils.viz_demo import convert_scene_output_to_glb, get_dynamic_mask_from_pairviewer +import matplotlib.pyplot as pl +import cv2 + + +# DEBUG +from dust3r.utils.viz_demo import render_colored_pointclouds, render_colored_pointclouds_pytorch3d +# DEBUG + +pl.ion() +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser_url = parser.add_mutually_exclusive_group() + parser_url.add_argument("--local_network", action='store_true', default=False, + help="make app accessible on local network: address will be set to 0.0.0.0") + parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1") + parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size") + parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). " + "If None, will search for an available port starting at 7860."), + default=None) + parser.add_argument("--weights", type=str, help="path to the model weights", default='checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth') + parser.add_argument("--model_name", type=str, default='Junyi42/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt', help="model name") + parser.add_argument("--device", type=str, default='cuda', help="pytorch device") + parser.add_argument("--output_dir", type=str, default='./demo_tmp', help="value for tempfile.tempdir") + parser.add_argument("--silent", action='store_true', default=False, + help="silence logs") + parser.add_argument("--input_dir", type=str, help="Path to input images directory", default=None) + parser.add_argument("--seq_name", type=str, help="Sequence name for evaluation", default='NULL') + parser.add_argument('--use_gt_davis_masks', action='store_true', default=False, help='Use ground truth masks for DAVIS') + parser.add_argument('--not_batchify', action='store_true', default=False, help='Use non batchify mode for global optimization') + parser.add_argument('--real_time', action='store_true', default=False, help='Realtime mode') + parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference') + + parser.add_argument('--fps', type=int, default=0, help='FPS for video processing') + parser.add_argument('--num_frames', type=int, default=200, help='Maximum number of frames for video processing') + + # Add "share" argument if you want to make the demo accessible on the public internet + parser.add_argument("--share", action='store_true', default=False, help="Share the demo") + return parser + +def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False, + clean_depth=False, transparent_cams=False, cam_size=0.05, show_cam=True, save_name=None, thr_for_init_conf=True): + """ + extract 3D_model (glb file) from a reconstructed scene + """ + if scene is None: + return None + # post processes + if clean_depth: + scene = scene.clean_pointcloud() + if mask_sky: + scene = scene.mask_sky() + + # get optimized values from scene + rgbimg = scene.imgs + focals = scene.get_focals().cpu() + cams2world = scene.get_im_poses().cpu() + # 3D pointcloud from depthmap, poses and intrinsics + pts3d = to_numpy(scene.get_pts3d(raw_pts=True)) + + # DEBUG: save the pts3d frame by frame, later visualize with cloudcompare + # import trimesh + # for i in range(len(pts3d)): + # pc = trimesh.PointCloud(pts3d[i].reshape(-1, 3)) + # pc.export(f'{outdir}/pts3d_frame_{i}.ply') + + scene.min_conf_thr = min_conf_thr + scene.thr_for_init_conf = thr_for_init_conf + msk = to_numpy(scene.get_masks()) + cmap = pl.get_cmap('viridis') + cam_color = [cmap(i/len(rgbimg))[:3] for i in range(len(rgbimg))] + cam_color = [(255*c[0], 255*c[1], 255*c[2]) for c in cam_color] + + # DEBUG, render images + pp = scene.get_principal_points().cpu() + import pdb; pdb.set_trace() + render_colored_pointclouds(outdir, rgbimg[0], pts3d[0], msk[0], focals, pp, cams2world[0:1]) + render_colored_pointclouds_pytorch3d(outdir, rgbimg[0], pts3d[0], msk[0], focals[0], pp[0], cams2world[0:1]) + + return convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud, + transparent_cams=transparent_cams, cam_size=cam_size, show_cam=show_cam, silent=silent, save_name=save_name, + cam_color=cam_color) + + +def get_reconstructed_scene(args, outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr, + as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, show_cam, scenegraph_type, winsize, refid, + seq_name, new_model_weights, temporal_smoothing_weight, translation_weight, shared_focal, + flow_loss_weight, flow_loss_start_iter, flow_loss_threshold, use_gt_mask, fps, num_frames): + """ + from a list of images, run dust3r inference, global aligner. + then run get_3D_model_from_scene + """ + translation_weight = float(translation_weight) + if new_model_weights != args.weights: + model = AsymmetricCroCo3DStereo.from_pretrained(new_model_weights).to(device) + model.eval() + if seq_name != "NULL": + dynamic_mask_path = f'data/davis/DAVIS/masked_images/480p/{seq_name}' + else: + dynamic_mask_path = None + imgs = load_images(filelist, size=image_size, verbose=not silent, dynamic_mask_root=dynamic_mask_path, fps=fps, num_frames=num_frames) + if len(imgs) == 1: + imgs = [imgs[0], copy.deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + if scenegraph_type == "swin" or scenegraph_type == "swinstride" or scenegraph_type == "swin2stride": + scenegraph_type = scenegraph_type + "-" + str(winsize) + "-noncyclic" + elif scenegraph_type == "oneref": + scenegraph_type = scenegraph_type + "-" + str(refid) + + pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True) + output = inference(pairs, model, device, batch_size=args.batch_size, verbose=not silent) + if len(imgs) > 2: + mode = GlobalAlignerMode.PointCloudOptimizer + scene = global_aligner(output, device=device, mode=mode, verbose=not silent, shared_focal = shared_focal, temporal_smoothing_weight=temporal_smoothing_weight, translation_weight=translation_weight, + flow_loss_weight=flow_loss_weight, flow_loss_start_epoch=flow_loss_start_iter, flow_loss_thre=flow_loss_threshold, use_self_mask=not use_gt_mask, + num_total_iter=niter, empty_cache= len(filelist) > 72, batchify=not args.not_batchify) + else: + mode = GlobalAlignerMode.PairViewer + scene = global_aligner(output, device=device, mode=mode, verbose=not silent) + lr = 0.01 + + if mode == GlobalAlignerMode.PointCloudOptimizer: + loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr) + + save_folder = f'{args.output_dir}/{seq_name}' #default is 'demo_tmp/NULL' + os.makedirs(save_folder, exist_ok=True) + outfile = get_3D_model_from_scene(save_folder, silent, scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam) + + poses = scene.save_tum_poses(f'{save_folder}/pred_traj.txt') + K = scene.save_intrinsics(f'{save_folder}/pred_intrinsics.txt') + depth_maps = scene.save_depth_maps(save_folder) + dynamic_masks = scene.save_dynamic_masks(save_folder) + conf = scene.save_conf_maps(save_folder) + init_conf = scene.save_init_conf_maps(save_folder) + rgbs = scene.save_rgb_imgs(save_folder) + enlarge_seg_masks(save_folder, kernel_size=5 if use_gt_mask else 3) + + # also return rgb, depth and confidence imgs + # depth is normalized with the max value for all images + # we apply the jet colormap on the confidence maps + rgbimg = scene.imgs + depths = to_numpy(scene.get_depthmaps()) + confs = to_numpy([c for c in scene.im_conf]) + init_confs = to_numpy([c for c in scene.init_conf_maps]) + cmap = pl.get_cmap('jet') + depths_max = max([d.max() for d in depths]) + depths = [cmap(d/depths_max) for d in depths] + confs_max = max([d.max() for d in confs]) + confs = [cmap(d/confs_max) for d in confs] + init_confs_max = max([d.max() for d in init_confs]) + init_confs = [cmap(d/init_confs_max) for d in init_confs] + + imgs = [] + for i in range(len(rgbimg)): + imgs.append(rgbimg[i]) + imgs.append(rgb(depths[i])) + imgs.append(rgb(confs[i])) + imgs.append(rgb(init_confs[i])) + + # if two images, and the shape is same, we can compute the dynamic mask + if len(rgbimg) == 2 and rgbimg[0].shape == rgbimg[1].shape: + motion_mask_thre = 0.35 + error_map = get_dynamic_mask_from_pairviewer(scene, both_directions=True, output_dir=args.output_dir, motion_mask_thre=motion_mask_thre) + # imgs.append(rgb(error_map)) + # apply threshold on the error map + normalized_error_map = (error_map - error_map.min()) / (error_map.max() - error_map.min()) + error_map_max = normalized_error_map.max() + error_map = cmap(normalized_error_map/error_map_max) + imgs.append(rgb(error_map)) + binary_error_map = (normalized_error_map > motion_mask_thre).astype(np.uint8) + imgs.append(rgb(binary_error_map*255)) + + return scene, outfile, imgs + + +def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type): + # if inputfiles[0] is a video, set the num_files to 200 + if inputfiles is not None and len(inputfiles) == 1 and inputfiles[0].name.endswith(('.mp4', '.avi', '.mov', '.MP4', '.AVI', '.MOV')): + num_files = 200 + else: + num_files = len(inputfiles) if inputfiles is not None else 1 + max_winsize = max(1, math.ceil((num_files-1)/2)) + if scenegraph_type == "swin" or scenegraph_type == "swin2stride" or scenegraph_type == "swinstride": + winsize = gradio.Slider(label="Scene Graph: Window Size", value=min(max_winsize,5), + minimum=1, maximum=max_winsize, step=1, visible=True) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files-1, step=1, visible=False) + elif scenegraph_type == "oneref": + winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files-1, step=1, visible=True) + else: + winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files-1, step=1, visible=False) + return winsize, refid + + +def get_reconstructed_scene_realtime(args, model, device, silent, image_size, filelist, scenegraph_type, refid, seq_name, fps, num_frames): + """ + from a list of images, run dust3r inference, global aligner. + then run get_3D_model_from_scene + """ + model.eval() + imgs = load_images(filelist, size=image_size, verbose=not silent, fps=fps, num_frames=num_frames) + if len(imgs) == 1: + imgs = [imgs[0], copy.deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + + if scenegraph_type == "oneref": + scenegraph_type = scenegraph_type + "-" + str(refid) + elif scenegraph_type == "oneref_mid": + scenegraph_type = "oneref-" + str(len(imgs) // 2) + else: + raise ValueError(f"Unknown scenegraph type for realtime mode: {scenegraph_type}") + + pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=False) + output = inference(pairs, model, device, batch_size=args.batch_size, verbose=not silent) + + save_folder = f'{args.output_dir}/{seq_name}' #default is 'demo_tmp/NULL' + os.makedirs(save_folder, exist_ok=True) + + + view1, view2, pred1, pred2 = output['view1'], output['view2'], output['pred1'], output['pred2'] + pts1 = pred1['pts3d'].detach().cpu().numpy() + pts2 = pred2['pts3d_in_other_view'].detach().cpu().numpy() + for batch_idx in range(len(view1['img'])): + colors1 = rgb(view1['img'][batch_idx]) + colors2 = rgb(view2['img'][batch_idx]) + xyzrgb1 = np.concatenate([pts1[batch_idx], colors1], axis=-1) #(H, W, 6) + xyzrgb2 = np.concatenate([pts2[batch_idx], colors2], axis=-1) + np.save(save_folder + '/pts3d1_p' + str(batch_idx) + '.npy', xyzrgb1) + np.save(save_folder + '/pts3d2_p' + str(batch_idx) + '.npy', xyzrgb2) + + conf1 = pred1['conf'][batch_idx].detach().cpu().numpy() + conf2 = pred2['conf'][batch_idx].detach().cpu().numpy() + np.save(save_folder + '/conf1_p' + str(batch_idx) + '.npy', conf1) + np.save(save_folder + '/conf2_p' + str(batch_idx) + '.npy', conf2) + + # save the imgs of two views + img1_rgb = cv2.cvtColor(colors1 * 255, cv2.COLOR_BGR2RGB) + img2_rgb = cv2.cvtColor(colors2 * 255, cv2.COLOR_BGR2RGB) + cv2.imwrite(save_folder + '/img1_p' + str(batch_idx) + '.png', img1_rgb) + cv2.imwrite(save_folder + '/img2_p' + str(batch_idx) + '.png', img2_rgb) + + return save_folder + + +def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False, args=None): + recon_fun = functools.partial(get_reconstructed_scene, args, tmpdirname, model, device, silent, image_size) + model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent) + with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="MonST3R Demo") as demo: + # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference + scene = gradio.State(None) + gradio.HTML(f'

MonST3R Demo

') + with gradio.Column(): + inputfiles = gradio.File(file_count="multiple") + with gradio.Row(): + schedule = gradio.Dropdown(["linear", "cosine"], + value='linear', label="schedule", info="For global alignment!") + niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000, + label="num_iterations", info="For global alignment!") + seq_name = gradio.Textbox(label="Sequence Name", placeholder="NULL", value=args.seq_name, info="For evaluation") + scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref", "swinstride", "swin2stride"], + value='swinstride', label="Scenegraph", + info="Define how to make pairs", + interactive=True) + winsize = gradio.Slider(label="Scene Graph: Window Size", value=5, + minimum=1, maximum=1, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False) + + run_btn = gradio.Button("Run") + + with gradio.Row(): + # adjust the confidence thresholdx + min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.1, minimum=0.0, maximum=20, step=0.01) + # adjust the camera size in the output pointcloud + cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001) + # adjust the temporal smoothing weight + temporal_smoothing_weight = gradio.Slider(label="temporal_smoothing_weight", value=0.01, minimum=0.0, maximum=0.1, step=0.001) + # add translation weight + translation_weight = gradio.Textbox(label="translation_weight", placeholder="1.0", value="1.0", info="For evaluation") + # change to another model + new_model_weights = gradio.Textbox(label="New Model", placeholder=args.weights, value=args.weights, info="Path to updated model weights") + with gradio.Row(): + as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud") + # two post process implemented + mask_sky = gradio.Checkbox(value=False, label="Mask sky") + clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps") + transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras") + # not to show camera + show_cam = gradio.Checkbox(value=True, label="Show Camera") + shared_focal = gradio.Checkbox(value=True, label="Shared Focal Length") + use_davis_gt_mask = gradio.Checkbox(value=False, label="Use GT Mask (DAVIS)") + with gradio.Row(): + flow_loss_weight = gradio.Slider(label="Flow Loss Weight", value=0.01, minimum=0.0, maximum=1.0, step=0.001) + flow_loss_start_iter = gradio.Slider(label="Flow Loss Start Iter", value=0.1, minimum=0, maximum=1, step=0.01) + flow_loss_threshold = gradio.Slider(label="Flow Loss Threshold", value=25, minimum=0, maximum=100, step=1) + # for video processing + fps = gradio.Slider(label="FPS", value=0, minimum=0, maximum=60, step=1) + num_frames = gradio.Slider(label="Num Frames", value=100, minimum=0, maximum=200, step=1) + + outmodel = gradio.Model3D() + outgallery = gradio.Gallery(label='rgb,depth,confidence, init_conf', columns=4, height="100%") + + # events + scenegraph_type.change(set_scenegraph_options, + inputs=[inputfiles, winsize, refid, scenegraph_type], + outputs=[winsize, refid]) + inputfiles.change(set_scenegraph_options, + inputs=[inputfiles, winsize, refid, scenegraph_type], + outputs=[winsize, refid]) + run_btn.click(fn=recon_fun, + inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud, + mask_sky, clean_depth, transparent_cams, cam_size, show_cam, + scenegraph_type, winsize, refid, seq_name, new_model_weights, + temporal_smoothing_weight, translation_weight, shared_focal, + flow_loss_weight, flow_loss_start_iter, flow_loss_threshold, use_davis_gt_mask, + fps, num_frames], + outputs=[scene, outmodel, outgallery]) + min_conf_thr.release(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + cam_size.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + as_pointcloud.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + mask_sky.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + clean_depth.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + transparent_cams.change(model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, show_cam], + outputs=outmodel) + demo.launch(share=args.share, server_name=server_name, server_port=server_port) + + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + + if args.output_dir is not None: + tmp_path = args.output_dir + os.makedirs(tmp_path, exist_ok=True) + tempfile.tempdir = tmp_path + + if args.server_name is not None: + server_name = args.server_name + else: + server_name = '0.0.0.0' if args.local_network else '127.0.0.1' + + if args.weights is not None and os.path.exists(args.weights): + weights_path = args.weights + else: + weights_path = args.model_name + + model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device) + + # Use the provided output_dir or create a temporary directory + tmpdirname = args.output_dir if args.output_dir is not None else tempfile.mkdtemp(suffix='monst3r_gradio_demo') + + if not args.silent: + print('Outputting stuff in', tmpdirname) + + if args.input_dir is not None: + # Process images in the input directory with default parameters + if os.path.isdir(args.input_dir): # input_dir is a directory of images + input_files = [os.path.join(args.input_dir, fname) for fname in sorted(os.listdir(args.input_dir))] + else: # input_dir is a video + input_files = [args.input_dir] + + if args.real_time: + recon_fun = functools.partial(get_reconstructed_scene_realtime, args, model, args.device, args.silent, args.image_size) + outfile = recon_fun( + filelist=input_files, + scenegraph_type='oneref_mid', + refid=0, + seq_name=args.seq_name, + fps=args.fps, + num_frames=args.num_frames, + ) + else: + recon_fun = functools.partial(get_reconstructed_scene, args, tmpdirname, model, args.device, args.silent, args.image_size) + # Call the function with default parameters + scene, outfile, imgs = recon_fun( + filelist=input_files, + schedule='linear', + niter=300, + min_conf_thr=1.1, + as_pointcloud=True, + mask_sky=False, + clean_depth=True, + transparent_cams=False, + cam_size=0.05, + show_cam=True, + scenegraph_type='swinstride', + winsize=5, + refid=0, + seq_name=args.seq_name, + new_model_weights=args.weights, + temporal_smoothing_weight=0.01, + translation_weight='1.0', + shared_focal=True, + flow_loss_weight=0.01, + flow_loss_start_iter=0.1, + flow_loss_threshold=25, + use_gt_mask=args.use_gt_davis_masks, + fps=args.fps, + num_frames=args.num_frames, + ) + print(f"Processing completed. Output saved in {tmpdirname}/{args.seq_name}") + else: + # Launch Gradio demo + main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent, args=args) diff --git a/monst3r/dust3r/depth_eval.py b/monst3r/dust3r/depth_eval.py new file mode 100644 index 0000000..5e00f71 --- /dev/null +++ b/monst3r/dust3r/depth_eval.py @@ -0,0 +1,336 @@ +import torch +import numpy as np +import cv2 +import glob +from pathlib import Path +from tqdm import tqdm +from dust3r.image_pairs import make_pairs +from dust3r.inference import inference +from dust3r.utils.image import load_images, rgb, enlarge_seg_masks +from copy import deepcopy +from scipy.optimize import minimize +import os +from collections import defaultdict +import dust3r.eval_metadata +from dust3r.eval_metadata import dataset_metadata + +def eval_mono_depth_estimation(args, model, device): + metadata = dataset_metadata.get(args.eval_dataset) + if metadata is None: + raise ValueError(f"Unknown dataset: {args.eval_dataset}") + + img_path = metadata.get('img_path') + if 'img_path_func' in metadata: + img_path = metadata['img_path_func'](args) + + process_func = metadata.get('process_func') + if process_func is None: + raise ValueError(f"No processing function defined for dataset: {args.eval_dataset}") + + for filelist, save_dir in process_func(args, img_path): + Path(save_dir).mkdir(parents=True, exist_ok=True) + eval_mono_depth(args, model, device, filelist, save_dir=save_dir) + + +def eval_mono_depth(args, model, device, filelist, save_dir=None): + model.eval() + load_img_size = 512 + for file in tqdm(filelist): + # construct the "image pair" for the single image + file = [file] + imgs = load_images(file, size=load_img_size, verbose=False, crop= not args.no_crop) + imgs = [imgs[0], deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + + pairs = make_pairs(imgs, symmetrize=True, prefilter=None) + output = inference(pairs, model, device, batch_size=1, verbose=False) + depth_map = output['pred1']['pts3d'][...,-1].mean(dim=0) + + if save_dir is not None: + #save the depth map to the save_dir as npy + np.save(f"{save_dir}/{file[0].split('/')[-1].replace('.png','depth.npy')}", depth_map.cpu().numpy()) + # also save the png + depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) + depth_map = (depth_map * 255).cpu().numpy().astype(np.uint8) + cv2.imwrite(f"{save_dir}/{file[0].split('/')[-1].replace('.png','depth.png')}", depth_map) + + + +## used for calculating the depth evaluation metrics + + +def group_by_directory(pathes, idx=-1): + """ + Groups the file paths based on the second-to-last directory in their paths. + + Parameters: + - pathes (list): List of file paths. + + Returns: + - dict: A dictionary where keys are the second-to-last directory names and values are lists of file paths. + """ + grouped_pathes = defaultdict(list) + + for path in pathes: + # Extract the second-to-last directory + dir_name = os.path.dirname(path).split('/')[idx] + grouped_pathes[dir_name].append(path) + + return grouped_pathes + + +def depth2disparity(depth, return_mask=False): + if isinstance(depth, torch.Tensor): + disparity = torch.zeros_like(depth) + elif isinstance(depth, np.ndarray): + disparity = np.zeros_like(depth) + non_negtive_mask = depth > 0 + disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask] + if return_mask: + return disparity, non_negtive_mask + else: + return disparity + +def absolute_error_loss(params, predicted_depth, ground_truth_depth): + s, t = params + + predicted_aligned = s * predicted_depth + t + + abs_error = np.abs(predicted_aligned - ground_truth_depth) + return np.sum(abs_error) + +def absolute_value_scaling(predicted_depth, ground_truth_depth, s=1, t=0): + predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1) + ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1) + + initial_params = [s, t] # s = 1, t = 0 + + result = minimize(absolute_error_loss, initial_params, args=(predicted_depth_np, ground_truth_depth_np)) + + s, t = result.x + return s, t + +def absolute_value_scaling2(predicted_depth, ground_truth_depth, s_init=1.0, t_init=0.0, lr=1e-4, max_iters=1000, tol=1e-6): + # Initialize s and t as torch tensors with requires_grad=True + s = torch.tensor([s_init], requires_grad=True, device=predicted_depth.device, dtype=predicted_depth.dtype) + t = torch.tensor([t_init], requires_grad=True, device=predicted_depth.device, dtype=predicted_depth.dtype) + + optimizer = torch.optim.Adam([s, t], lr=lr) + + prev_loss = None + + for i in range(max_iters): + optimizer.zero_grad() + + # Compute predicted aligned depth + predicted_aligned = s * predicted_depth + t + + # Compute absolute error + abs_error = torch.abs(predicted_aligned - ground_truth_depth) + + # Compute loss + loss = torch.sum(abs_error) + + # Backpropagate + loss.backward() + + # Update parameters + optimizer.step() + + # Check convergence + if prev_loss is not None and torch.abs(prev_loss - loss) < tol: + break + + prev_loss = loss.item() + + return s.detach().item(), t.detach().item() + +def depth_evaluation(predicted_depth_original, ground_truth_depth_original, max_depth=80, custom_mask=None, post_clip_min=None, post_clip_max=None, pre_clip_min=None, pre_clip_max=None, + align_with_lstsq=False, align_with_lad=False, align_with_lad2=False, lr=1e-4, max_iters=1000, use_gpu=False, align_with_scale=False, + disp_input=False): + """ + Evaluate the depth map using various metrics and return a depth error parity map, with an option for least squares alignment. + + Args: + predicted_depth (numpy.ndarray or torch.Tensor): The predicted depth map. + ground_truth_depth (numpy.ndarray or torch.Tensor): The ground truth depth map. + max_depth (float): The maximum depth value to consider. Default is 80 meters. + align_with_lstsq (bool): If True, perform least squares alignment of the predicted depth with ground truth. + + Returns: + dict: A dictionary containing the evaluation metrics. + torch.Tensor: The depth error parity map. + """ + + if isinstance(predicted_depth_original, np.ndarray): + predicted_depth_original = torch.from_numpy(predicted_depth_original) + if isinstance(ground_truth_depth_original, np.ndarray): + ground_truth_depth_original = torch.from_numpy(ground_truth_depth_original) + if custom_mask is not None and isinstance(custom_mask, np.ndarray): + custom_mask = torch.from_numpy(custom_mask) + + # if the dimension is 3, flatten to 2d along the batch dimension + if predicted_depth_original.dim() == 3: + _, h, w = predicted_depth_original.shape + predicted_depth_original = predicted_depth_original.view(-1, w) + ground_truth_depth_original = ground_truth_depth_original.view(-1, w) + if custom_mask is not None: + custom_mask = custom_mask.view(-1, w) + + # put to device + if use_gpu: + predicted_depth_original = predicted_depth_original.cuda() + ground_truth_depth_original = ground_truth_depth_original.cuda() + + # Filter out depths greater than max_depth + if max_depth is not None: + mask = (ground_truth_depth_original > 0) & (ground_truth_depth_original < max_depth) + else: + mask = (ground_truth_depth_original > 0) + + predicted_depth = predicted_depth_original[mask] + ground_truth_depth = ground_truth_depth_original[mask] + + # Clip the depth values + if pre_clip_min is not None: + predicted_depth = torch.clamp(predicted_depth, min=pre_clip_min) + if pre_clip_max is not None: + predicted_depth = torch.clamp(predicted_depth, max=pre_clip_max) + + if disp_input: # align the pred to gt in the disparity space + real_gt = ground_truth_depth.clone() + ground_truth_depth = 1 / (ground_truth_depth + 1e-8) + + # various alignment methods + if align_with_lstsq: + # Convert to numpy for lstsq + predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1, 1) + ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1, 1) + + # Add a column of ones for the shift term + A = np.hstack([predicted_depth_np, np.ones_like(predicted_depth_np)]) + + # Solve for scale (s) and shift (t) using least squares + result = np.linalg.lstsq(A, ground_truth_depth_np, rcond=None) + s, t = result[0][0], result[0][1] + + # convert to torch tensor + s = torch.tensor(s, device=predicted_depth_original.device) + t = torch.tensor(t, device=predicted_depth_original.device) + + # Apply scale and shift + predicted_depth = s * predicted_depth + t + elif align_with_lad: + s, t = absolute_value_scaling(predicted_depth, ground_truth_depth, s=torch.median(ground_truth_depth) / torch.median(predicted_depth)) + predicted_depth = s * predicted_depth + t + elif align_with_lad2: + s_init = (torch.median(ground_truth_depth) / torch.median(predicted_depth)).item() + s, t = absolute_value_scaling2(predicted_depth, ground_truth_depth, s_init=s_init, lr=lr, max_iters=max_iters) + predicted_depth = s * predicted_depth + t + elif align_with_scale: + # Compute initial scale factor 's' using the closed-form solution (L2 norm) + dot_pred_gt = torch.nanmean(ground_truth_depth) + dot_pred_pred = torch.nanmean(predicted_depth) + s = dot_pred_gt / dot_pred_pred + + # Iterative reweighted least squares using the Weiszfeld method + for _ in range(10): + # Compute residuals between scaled predictions and ground truth + residuals = s * predicted_depth - ground_truth_depth + abs_residuals = residuals.abs() + 1e-8 # Add small constant to avoid division by zero + + # Compute weights inversely proportional to the residuals + weights = 1.0 / abs_residuals + + # Update 's' using weighted sums + weighted_dot_pred_gt = torch.sum(weights * predicted_depth * ground_truth_depth) + weighted_dot_pred_pred = torch.sum(weights * predicted_depth ** 2) + s = weighted_dot_pred_gt / weighted_dot_pred_pred + + # Optionally clip 's' to prevent extreme scaling + s = s.clamp(min=1e-3) + + # Detach 's' if you want to stop gradients from flowing through it + s = s.detach() + + # Apply the scale factor to the predicted depth + predicted_depth = s * predicted_depth + + else: + # Align the predicted depth with the ground truth using median scaling + scale_factor = torch.median(ground_truth_depth) / torch.median(predicted_depth) + predicted_depth *= scale_factor + + if disp_input: + # convert back to depth + ground_truth_depth = real_gt + predicted_depth = depth2disparity(predicted_depth) + + # Clip the predicted depth values + if post_clip_min is not None: + predicted_depth = torch.clamp(predicted_depth, min=post_clip_min) + if post_clip_max is not None: + predicted_depth = torch.clamp(predicted_depth, max=post_clip_max) + + if custom_mask is not None: + assert custom_mask.shape == ground_truth_depth_original.shape + mask_within_mask = custom_mask.cpu()[mask] + predicted_depth = predicted_depth[mask_within_mask] + ground_truth_depth = ground_truth_depth[mask_within_mask] + + # Calculate the metrics + abs_rel = torch.mean(torch.abs(predicted_depth - ground_truth_depth) / ground_truth_depth).item() + sq_rel = torch.mean(((predicted_depth - ground_truth_depth) ** 2) / ground_truth_depth).item() + + # Correct RMSE calculation + rmse = torch.sqrt(torch.mean((predicted_depth - ground_truth_depth) ** 2)).item() + + # Clip the depth values to avoid log(0) + predicted_depth = torch.clamp(predicted_depth, min=1e-5) + log_rmse = torch.sqrt(torch.mean((torch.log(predicted_depth) - torch.log(ground_truth_depth)) ** 2)).item() + + # Calculate the accuracy thresholds + max_ratio = torch.maximum(predicted_depth / ground_truth_depth, ground_truth_depth / predicted_depth) + threshold_1 = torch.mean((max_ratio < 1.25).float()).item() + threshold_2 = torch.mean((max_ratio < 1.25 ** 2).float()).item() + threshold_3 = torch.mean((max_ratio < 1.25 ** 3).float()).item() + + # Compute the depth error parity map + if align_with_lstsq or align_with_lad or align_with_lad2: + predicted_depth_original = predicted_depth_original * s + t + if disp_input: predicted_depth_original = depth2disparity(predicted_depth_original) + depth_error_parity_map = torch.abs(predicted_depth_original - ground_truth_depth_original) / ground_truth_depth_original + elif align_with_scale: + predicted_depth_original = predicted_depth_original * s + if disp_input: predicted_depth_original = depth2disparity(predicted_depth_original) + depth_error_parity_map = torch.abs(predicted_depth_original - ground_truth_depth_original) / ground_truth_depth_original + else: + predicted_depth_original = predicted_depth_original * scale_factor + if disp_input: predicted_depth_original = depth2disparity(predicted_depth_original) + depth_error_parity_map = torch.abs(predicted_depth_original - ground_truth_depth_original) / ground_truth_depth_original + + # Reshape the depth_error_parity_map back to the original image size + depth_error_parity_map_full = torch.zeros_like(ground_truth_depth_original) + depth_error_parity_map_full = torch.where(mask, depth_error_parity_map, depth_error_parity_map_full) + + predict_depth_map_full = predicted_depth_original + + gt_depth_map_full = torch.zeros_like(ground_truth_depth_original) + gt_depth_map_full = torch.where(mask, ground_truth_depth_original, gt_depth_map_full) + + num_valid_pixels = torch.sum(mask).item() if custom_mask is None else torch.sum(mask_within_mask).item() + if num_valid_pixels == 0: + abs_rel, sq_rel, rmse, log_rmse, threshold_1, threshold_2, threshold_3 = 0, 0, 0, 0, 0, 0, 0 + + results = { + 'Abs Rel': abs_rel, + 'Sq Rel': sq_rel, + 'RMSE': rmse, + 'Log RMSE': log_rmse, + 'δ < 1.25': threshold_1, + 'δ < 1.25^2': threshold_2, + 'δ < 1.25^3': threshold_3, + 'valid_pixels': num_valid_pixels + } + + return results, depth_error_parity_map_full, predict_depth_map_full, gt_depth_map_full diff --git a/monst3r/dust3r/eval_metadata.py b/monst3r/dust3r/eval_metadata.py new file mode 100644 index 0000000..39eea99 --- /dev/null +++ b/monst3r/dust3r/eval_metadata.py @@ -0,0 +1,132 @@ +import os +import glob +from tqdm import tqdm + +# Define the merged dataset metadata dictionary +dataset_metadata = { + 'davis': { + 'img_path': "data/davis/DAVIS/JPEGImages/480p", + 'mask_path': "data/davis/DAVIS/masked_images/480p", + 'dir_path_func': lambda img_path, seq: os.path.join(img_path, seq), + 'gt_traj_func': lambda img_path, anno_path, seq: None, + 'traj_format': None, + 'seq_list': None, + 'full_seq': True, + 'mask_path_seq_func': lambda mask_path, seq: os.path.join(mask_path, seq), + 'skip_condition': None, + 'process_func': None, # Not used in mono depth estimation + }, + 'kitti': { + 'img_path': "data/kitti/depth_selection/val_selection_cropped/image_gathered", # Default path + 'mask_path': None, + 'dir_path_func': lambda img_path, seq: os.path.join(img_path, seq), + 'gt_traj_func': lambda img_path, anno_path, seq: None, + 'traj_format': None, + 'seq_list': None, + 'full_seq': True, + 'mask_path_seq_func': lambda mask_path, seq: None, + 'skip_condition': None, + 'process_func': lambda args, img_path: process_kitti(args, img_path), + }, + 'bonn': { + 'img_path': "data/bonn/rgbd_bonn_dataset", + 'mask_path': None, + 'dir_path_func': lambda img_path, seq: os.path.join(img_path, f'rgbd_bonn_{seq}', 'rgb_110'), + 'gt_traj_func': lambda img_path, anno_path, seq: os.path.join(img_path, f'rgbd_bonn_{seq}', 'groundtruth_110.txt'), + 'traj_format': 'tum', + 'seq_list': ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"], + 'full_seq': False, + 'mask_path_seq_func': lambda mask_path, seq: None, + 'skip_condition': None, + 'process_func': lambda args, img_path: process_bonn(args, img_path), + }, + 'nyu': { + 'img_path': "data/nyu-v2/val/nyu_images", + 'mask_path': None, + 'process_func': lambda args, img_path: process_nyu(args, img_path), + }, + 'scannet': { + 'img_path': "data/scannetv2", + 'mask_path': None, + 'dir_path_func': lambda img_path, seq: os.path.join(img_path, seq, 'color_90'), + 'gt_traj_func': lambda img_path, anno_path, seq: os.path.join(img_path, seq, 'pose_90.txt'), + 'traj_format': 'replica', + 'seq_list': None, + 'full_seq': True, + 'mask_path_seq_func': lambda mask_path, seq: None, + 'skip_condition': lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)), + 'process_func': lambda args, img_path: process_scannet(args, img_path), + }, + 'tum': { + 'img_path': "data/tum", + 'mask_path': None, + 'dir_path_func': lambda img_path, seq: os.path.join(img_path, seq, 'rgb_90'), + 'gt_traj_func': lambda img_path, anno_path, seq: os.path.join(img_path, seq, 'groundtruth_90.txt'), + 'traj_format': 'tum', + 'seq_list': None, + 'full_seq': True, + 'mask_path_seq_func': lambda mask_path, seq: None, + 'skip_condition': None, + 'process_func': None, + }, + 'sintel': { + 'img_path': "data/sintel/training/final", + 'anno_path': "data/sintel/training/camdata_left", + 'mask_path': None, + 'dir_path_func': lambda img_path, seq: os.path.join(img_path, seq), + 'gt_traj_func': lambda img_path, anno_path, seq: os.path.join(anno_path, seq), + 'traj_format': None, + 'seq_list': ["alley_2", "ambush_4", "ambush_5", "ambush_6", "cave_2", "cave_4", "market_2", + "market_5", "market_6", "shaman_3", "sleeping_1", "sleeping_2", "temple_2", "temple_3"], + 'full_seq': False, + 'mask_path_seq_func': lambda mask_path, seq: None, + 'skip_condition': None, + 'process_func': lambda args, img_path: process_sintel(args, img_path), + }, +} + +# Define processing functions for each dataset +def process_kitti(args, img_path): + for dir in tqdm(sorted(glob.glob(f'{img_path}/*'))): + filelist = sorted(glob.glob(f'{dir}/*.png')) + save_dir = f'{args.output_dir}/{os.path.basename(dir)}' + yield filelist, save_dir + +def process_bonn(args, img_path): + if args.full_seq: + for dir in tqdm(sorted(glob.glob(f'{img_path}/*/'))): + filelist = sorted(glob.glob(f'{dir}/rgb/*.png')) + save_dir = f'{args.output_dir}/{os.path.basename(os.path.dirname(dir))}' + yield filelist, save_dir + else: + seq_list = ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"] if args.seq_list is None else args.seq_list + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f'{img_path}/rgbd_bonn_{seq}/rgb_110/*.png')) + save_dir = f'{args.output_dir}/{seq}' + yield filelist, save_dir + +def process_nyu(args, img_path): + filelist = sorted(glob.glob(f'{img_path}/*.png')) + save_dir = f'{args.output_dir}' + yield filelist, save_dir + +def process_scannet(args, img_path): + seq_list = sorted(glob.glob(f'{img_path}/*')) + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f'{seq}/color_90/*.jpg')) + save_dir = f'{args.output_dir}/{os.path.basename(seq)}' + yield filelist, save_dir + +def process_sintel(args, img_path): + if args.full_seq: + for dir in tqdm(sorted(glob.glob(f'{img_path}/*/'))): + filelist = sorted(glob.glob(f'{dir}/*.png')) + save_dir = f'{args.output_dir}/{os.path.basename(os.path.dirname(dir))}' + yield filelist, save_dir + else: + seq_list = ["alley_2", "ambush_4", "ambush_5", "ambush_6", "cave_2", "cave_4", "market_2", + "market_5", "market_6", "shaman_3", "sleeping_1", "sleeping_2", "temple_2", "temple_3"] + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f'{img_path}/{seq}/*.png')) + save_dir = f'{args.output_dir}/{seq}' + yield filelist, save_dir \ No newline at end of file diff --git a/monst3r/dust3r/heads/__init__.py b/monst3r/dust3r/heads/__init__.py new file mode 100644 index 0000000..53d0aa5 --- /dev/null +++ b/monst3r/dust3r/heads/__init__.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# head factory +# -------------------------------------------------------- +from .linear_head import LinearPts3d +from .dpt_head import create_dpt_head + + +def head_factory(head_type, output_mode, net, has_conf=False): + """" build a prediction head for the decoder + """ + if head_type == 'linear' and output_mode == 'pts3d': + return LinearPts3d(net, has_conf) + elif head_type == 'dpt' and output_mode == 'pts3d': + return create_dpt_head(net, has_conf=has_conf) + else: + raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}") diff --git a/monst3r/dust3r/heads/dpt_head.py b/monst3r/dust3r/heads/dpt_head.py new file mode 100644 index 0000000..b7bdc9f --- /dev/null +++ b/monst3r/dust3r/heads/dpt_head.py @@ -0,0 +1,115 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# dpt head implementation for DUST3R +# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ; +# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True +# the forward function also takes as input a dictionnary img_info with key "height" and "width" +# for PixelwiseTask, the output will be of dimension B x num_channels x H x W +# -------------------------------------------------------- +from einops import rearrange +from typing import List +import torch +import torch.nn as nn +from dust3r.heads.postprocess import postprocess +import dust3r.utils.path_to_croco # noqa: F401 +from models.dpt_block import DPTOutputAdapter # noqa + + +class DPTOutputAdapter_fix(DPTOutputAdapter): + """ + Adapt croco's DPTOutputAdapter implementation for dust3r: + remove duplicated weigths, and fix forward for dust3r + """ + + def init(self, dim_tokens_enc=768): + super().init(dim_tokens_enc) + # these are duplicated weights + del self.act_1_postprocess + del self.act_2_postprocess + del self.act_3_postprocess + del self.act_4_postprocess + + def forward(self, encoder_tokens: List[torch.Tensor], image_size=None): + assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' + # H, W = input_info['image_size'] + image_size = self.image_size if image_size is None else image_size + H, W = image_size + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + # Hook decoder onto 4 layers from specified ViT layers + layers = [encoder_tokens[hook] for hook in self.hooks] + + # Extract only task-relevant tokens and ignore global tokens. + layers = [self.adapt_tokens(l) for l in layers] + + # Reshape tokens to spatial representation + layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers] + + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + # Project layers to chosen feature dim + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]] + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + # Output head + out = self.head(path_1) + + return out + + +class PixelwiseTaskWithDPT(nn.Module): + """ DPT module for dust3r, can return 3D points + confidence for all pixels""" + + def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None, + output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs): + super(PixelwiseTaskWithDPT, self).__init__() + self.return_all_layers = True # backbone needs to return all layers + self.postprocess = postprocess + self.depth_mode = depth_mode + self.conf_mode = conf_mode + + assert n_cls_token == 0, "Not implemented" + dpt_args = dict(output_width_ratio=output_width_ratio, + num_channels=num_channels, + **kwargs) + if hooks_idx is not None: + dpt_args.update(hooks=hooks_idx) + self.dpt = DPTOutputAdapter_fix(**dpt_args) + dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens} + self.dpt.init(**dpt_init_args) + + def forward(self, x, img_info): + out = self.dpt(x, image_size=(img_info[0], img_info[1])) + if self.postprocess: + out = self.postprocess(out, self.depth_mode, self.conf_mode) + return out + + +def create_dpt_head(net, has_conf=False): + """ + return PixelwiseTaskWithDPT for given net params + """ + assert net.dec_depth > 9 + l2 = net.dec_depth + feature_dim = 256 + last_dim = feature_dim//2 + out_nchan = 3 + ed = net.enc_embed_dim + dd = net.dec_embed_dim + return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf, + feature_dim=feature_dim, + last_dim=last_dim, + hooks_idx=[0, l2*2//4, l2*3//4, l2], + dim_tokens=[ed, dd, dd, dd], + postprocess=postprocess, + depth_mode=net.depth_mode, + conf_mode=net.conf_mode, + head_type='regression') diff --git a/monst3r/dust3r/heads/linear_head.py b/monst3r/dust3r/heads/linear_head.py new file mode 100644 index 0000000..6b697f2 --- /dev/null +++ b/monst3r/dust3r/heads/linear_head.py @@ -0,0 +1,41 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# linear head implementation for DUST3R +# -------------------------------------------------------- +import torch.nn as nn +import torch.nn.functional as F +from dust3r.heads.postprocess import postprocess + + +class LinearPts3d (nn.Module): + """ + Linear head for dust3r + Each token outputs: - 16x16 3D points (+ confidence) + """ + + def __init__(self, net, has_conf=False): + super().__init__() + self.patch_size = net.patch_embed.patch_size[0] + self.depth_mode = net.depth_mode + self.conf_mode = net.conf_mode + self.has_conf = has_conf + + self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2) + + def setup(self, croconet): + pass + + def forward(self, decout, img_shape): + H, W = img_shape + tokens = decout[-1] + B, S, D = tokens.shape + + # extract 3D points + feat = self.proj(tokens) # B,S,D + feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) + feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W + + # permute + norm depth + return postprocess(feat, self.depth_mode, self.conf_mode) diff --git a/monst3r/dust3r/heads/postprocess.py b/monst3r/dust3r/heads/postprocess.py new file mode 100644 index 0000000..cd68a90 --- /dev/null +++ b/monst3r/dust3r/heads/postprocess.py @@ -0,0 +1,58 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# post process function for all heads: extract 3D points/confidence from output +# -------------------------------------------------------- +import torch + + +def postprocess(out, depth_mode, conf_mode): + """ + extract 3D points/confidence from prediction head output + """ + fmap = out.permute(0, 2, 3, 1) # B,H,W,3 + res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode)) + + if conf_mode is not None: + res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode) + return res + + +def reg_dense_depth(xyz, mode): + """ + extract 3D points from prediction head output + """ + mode, vmin, vmax = mode + + no_bounds = (vmin == -float('inf')) and (vmax == float('inf')) + assert no_bounds + + if mode == 'linear': + if no_bounds: + return xyz # [-inf, +inf] + return xyz.clip(min=vmin, max=vmax) + + # distance to origin + d = xyz.norm(dim=-1, keepdim=True) + xyz = xyz / d.clip(min=1e-8) + + if mode == 'square': + return xyz * d.square() + + if mode == 'exp': + return xyz * torch.expm1(d) + + raise ValueError(f'bad {mode=}') + + +def reg_dense_conf(x, mode): + """ + extract confidence from prediction head output + """ + mode, vmin, vmax = mode + if mode == 'exp': + return vmin + x.exp().clip(max=vmax-vmin) + if mode == 'sigmoid': + return (vmax - vmin) * torch.sigmoid(x) + vmin + raise ValueError(f'bad {mode=}') diff --git a/monst3r/dust3r/image_pairs.py b/monst3r/dust3r/image_pairs.py new file mode 100644 index 0000000..2b2f4fb --- /dev/null +++ b/monst3r/dust3r/image_pairs.py @@ -0,0 +1,112 @@ +# -------------------------------------------------------- +# utilities needed to load image pairs +# -------------------------------------------------------- +import numpy as np +import torch + + +def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True, force_symmetrize=False): + pairs = [] + if scene_graph == 'complete': # complete graph + for i in range(len(imgs)): + for j in range(i): + pairs.append((imgs[i], imgs[j])) + elif scene_graph.startswith('swin'): + iscyclic = not scene_graph.endswith('noncyclic') + try: + winsize = int(scene_graph.split('-')[1]) + except Exception as e: + winsize = 3 + pairsid = set() + if scene_graph.startswith('swinstride'): + stride = 2 + elif scene_graph.startswith('swin2stride'): + stride = 3 + else: + stride = 1 + if scene_graph.startswith('swinskip_start'): + start = 2 + else: + start = 1 + for i in range(len(imgs)): + for j in range(start, stride*winsize + start, stride): + idx = (i + j) + if iscyclic: + idx = idx % len(imgs) # explicit loop closure + if idx >= len(imgs): + continue + pairsid.add((i, idx) if i < idx else (idx, i)) + for i, j in pairsid: + pairs.append((imgs[i], imgs[j])) + elif scene_graph.startswith('logwin'): + iscyclic = not scene_graph.endswith('noncyclic') + try: + winsize = int(scene_graph.split('-')[1]) + except Exception as e: + winsize = 3 + offsets = [2**i for i in range(winsize)] + pairsid = set() + for i in range(len(imgs)): + ixs_l = [i - off for off in offsets] + ixs_r = [i + off for off in offsets] + for j in ixs_l + ixs_r: + if iscyclic: + j = j % len(imgs) # Explicit loop closure + if j < 0 or j >= len(imgs) or j == i: + continue + pairsid.add((i, j) if i < j else (j, i)) + for i, j in pairsid: + pairs.append((imgs[i], imgs[j])) + elif scene_graph.startswith('oneref'): + refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0 + for j in range(len(imgs)): + if j != refid: + pairs.append((imgs[refid], imgs[j])) + + if (symmetrize and not scene_graph.startswith('oneref') and not scene_graph.startswith('swin-1')) or len(imgs) == 2 or force_symmetrize: + pairs += [(img2, img1) for img1, img2 in pairs] + + # now, remove edges + if isinstance(prefilter, str) and prefilter.startswith('seq'): + pairs = filter_pairs_seq(pairs, int(prefilter[3:])) + + if isinstance(prefilter, str) and prefilter.startswith('cyc'): + pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True) + + return pairs + + +def sel(x, kept): + if isinstance(x, dict): + return {k: sel(v, kept) for k, v in x.items()} + if isinstance(x, (torch.Tensor, np.ndarray)): + return x[kept] + if isinstance(x, (tuple, list)): + return type(x)([x[k] for k in kept]) + + +def _filter_edges_seq(edges, seq_dis_thr, cyclic=False): + # number of images + n = max(max(e) for e in edges) + 1 + + kept = [] + for e, (i, j) in enumerate(edges): + dis = abs(i - j) + if cyclic: + dis = min(dis, abs(i + n - j), abs(i - n - j)) + if dis <= seq_dis_thr: + kept.append(e) + return kept + + +def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False): + edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs] + kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) + return [pairs[i] for i in kept] + + +def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False): + edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])] + kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) + print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges') + return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept) diff --git a/monst3r/dust3r/inference.py b/monst3r/dust3r/inference.py new file mode 100644 index 0000000..a8ecbc4 --- /dev/null +++ b/monst3r/dust3r/inference.py @@ -0,0 +1,179 @@ +# -------------------------------------------------------- +# utilities needed for the inference +# -------------------------------------------------------- +import tqdm +import torch +from dust3r.utils.device import to_cpu, collate_with_cat +from dust3r.utils.misc import invalid_to_nans +from dust3r.utils.geometry import depthmap_to_pts3d, geotrf +from dust3r.viz import SceneViz, auto_cam_size +from dust3r.utils.image import rgb + +import numpy as np +import torch + +def _interleave_imgs(img1, img2): + res = {} + for key, value1 in img1.items(): + value2 = img2[key] + if isinstance(value1, torch.Tensor) and value1.ndim == value2.ndim: + value = torch.stack((value1, value2), dim=1).flatten(0, 1) + else: + value = [x for pair in zip(value1, value2) for x in pair] + res[key] = value + return res + + +def make_batch_symmetric(batch): + view1, view2 = batch + view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1)) + return view1, view2 + +def visualize_results(view1, view2, pred1, pred2, save_dir='./tmp', save_name=None, visualize_type='gt'): + # visualize_type: 'gt' or 'pred' + viz = SceneViz() + views = [view1, view2] + poses = [views[view_idx]['camera_pose'][0] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.5) + if visualize_type == 'pred': + cam_size *= 0.1 + views[0]['pts3d'] = geotrf(poses[0], pred1['pts3d']) # convert from X_camera1 to X_world + views[1]['pts3d'] = geotrf(poses[0], pred2['pts3d_in_other_view']) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'][0] + valid_mask = views[view_idx]['valid_mask'][0] + colors = rgb(views[view_idx]['img'][0]) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'][0], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(255, 0, 0), + image=colors, + cam_size=cam_size) + if save_name is None: + save_name = f'{views[0]["dataset"][0]}_{views[0]["label"][0]}_{views[0]["instance"][0]}_{views[1]["instance"][0]}_{visualize_type}' + save_path = save_dir+'/'+save_name+'.glb' + print(f'Saving visualization to {save_path}') + return viz.save_glb(save_path) + +def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None): + view1, view2 = batch + ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng']) + for view in batch: + for name in view.keys(): # pseudo_focal + if name in ignore_keys: + continue + view[name] = view[name].to(device, non_blocking=True) + + if symmetrize_batch: + view1, view2 = make_batch_symmetric(batch) + + with torch.cuda.amp.autocast(enabled=bool(use_amp)): + pred1, pred2 = model(view1, view2) + + # loss is supposed to be symmetric + with torch.cuda.amp.autocast(enabled=False): + loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None + + result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss) + + return result[ret] if ret else result + + +@torch.no_grad() +def inference(pairs, model, device, batch_size=8, verbose=True): + if verbose: + print(f'>> Inference with model on {len(pairs)} image pairs') + result = [] + + # first, check if all images have the same size + multiple_shapes = not (check_if_same_size(pairs)) + if multiple_shapes: # force bs=1 + batch_size = 1 + + for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose): + + res = loss_of_one_batch(collate_with_cat(pairs[i:i+batch_size]), model, None, device) + + result.append(to_cpu(res)) + + result = collate_with_cat(result, lists=multiple_shapes) + + return result + + +def check_if_same_size(pairs): + shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs] + shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs] + return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2) + + +def get_pred_pts3d(gt, pred, use_pose=False): + if 'depth' in pred and 'pseudo_focal' in pred: + try: + pp = gt['camera_intrinsics'][..., :2, 2] + except KeyError: + pp = None + pts3d = depthmap_to_pts3d(**pred, pp=pp) + + elif 'pts3d' in pred: + # pts3d from my camera + pts3d = pred['pts3d'] + + elif 'pts3d_in_other_view' in pred: + # pts3d from the other camera, already transformed + assert use_pose is True + return pred['pts3d_in_other_view'] # return! + + if use_pose: + camera_pose = pred.get('camera_pose') + assert camera_pose is not None + pts3d = geotrf(camera_pose, pts3d) + + return pts3d + + +def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None): + assert gt_pts1.ndim == pr_pts1.ndim == 4 + assert gt_pts1.shape == pr_pts1.shape + if gt_pts2 is not None: + assert gt_pts2.ndim == pr_pts2.ndim == 4 + assert gt_pts2.shape == pr_pts2.shape + + # concat the pointcloud + nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2) + nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None + + pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2) + pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None + + all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1 + all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1 + + dot_gt_pr = (all_pr * all_gt).sum(dim=-1) + dot_gt_gt = all_gt.square().sum(dim=-1) + + if fit_mode.startswith('avg'): + # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1) + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + elif fit_mode.startswith('median'): + scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values + elif fit_mode.startswith('weiszfeld'): + # init scaling with l2 closed form + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + # iterative re-weighted least-squares + for iter in range(10): + # re-weighting by inverse of distance + dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1) + # print(dis.nanmean(-1)) + w = dis.clip_(min=1e-8).reciprocal() + # update the scaling with the new weights + scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1) + else: + raise ValueError(f'bad {fit_mode=}') + + if fit_mode.endswith('stop_grad'): + scaling = scaling.detach() + + scaling = scaling.clip(min=1e-3) + # assert scaling.isfinite().all(), bb() + return scaling diff --git a/monst3r/dust3r/losses.py b/monst3r/dust3r/losses.py new file mode 100644 index 0000000..a8bd465 --- /dev/null +++ b/monst3r/dust3r/losses.py @@ -0,0 +1,299 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Implementation of DUSt3R training losses +# -------------------------------------------------------- +from copy import copy, deepcopy +import torch +import torch.nn as nn + +from dust3r.inference import get_pred_pts3d, find_opt_scaling +from dust3r.utils.geometry import inv, geotrf, normalize_pointcloud +from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale + + +def Sum(*losses_and_masks): + loss, mask = losses_and_masks[0] + if loss.ndim > 0: + # we are actually returning the loss for every pixels + return losses_and_masks + else: + # we are returning the global loss + for loss2, mask2 in losses_and_masks[1:]: + loss = loss + loss2 + return loss + + +class BaseCriterion(nn.Module): + def __init__(self, reduction='mean'): + super().__init__() + self.reduction = reduction + + +class LLoss (BaseCriterion): + """ L-norm loss + """ + + def forward(self, a, b): + assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}' + dist = self.distance(a, b) + assert dist.ndim == a.ndim - 1 # one dimension less + if self.reduction == 'none': + return dist + if self.reduction == 'sum': + return dist.sum() + if self.reduction == 'mean': + return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) + raise ValueError(f'bad {self.reduction=} mode') + + def distance(self, a, b): + raise NotImplementedError() + + +class L21Loss (LLoss): + """ Euclidean distance between 3d points """ + + def distance(self, a, b): + return torch.norm(a - b, dim=-1) # normalized L2 distance + + +L21 = L21Loss() + + +class Criterion (nn.Module): + def __init__(self, criterion=None): + super().__init__() + assert isinstance(criterion, BaseCriterion), f'{criterion} is not a proper criterion!' + self.criterion = copy(criterion) + + def get_name(self): + return f'{type(self).__name__}({self.criterion})' + + def with_reduction(self, mode='none'): + res = loss = deepcopy(self) + while loss is not None: + assert isinstance(loss, Criterion) + loss.criterion.reduction = mode # make it return the loss for each sample + loss = loss._loss2 # we assume loss is a Multiloss + return res + + +class MultiLoss (nn.Module): + """ Easily combinable losses (also keep track of individual loss values): + loss = MyLoss1() + 0.1*MyLoss2() + Usage: + Inherit from this class and override get_name() and compute_loss() + """ + + def __init__(self): + super().__init__() + self._alpha = 1 + self._loss2 = None + + def compute_loss(self, *args, **kwargs): + raise NotImplementedError() + + def get_name(self): + raise NotImplementedError() + + def __mul__(self, alpha): + assert isinstance(alpha, (int, float)) + res = copy(self) + res._alpha = alpha + return res + __rmul__ = __mul__ # same + + def __add__(self, loss2): + assert isinstance(loss2, MultiLoss) + res = cur = copy(self) + # find the end of the chain + while cur._loss2 is not None: + cur = cur._loss2 + cur._loss2 = loss2 + return res + + def __repr__(self): + name = self.get_name() + if self._alpha != 1: + name = f'{self._alpha:g}*{name}' + if self._loss2: + name = f'{name} + {self._loss2}' + return name + + def forward(self, *args, **kwargs): + loss = self.compute_loss(*args, **kwargs) + if isinstance(loss, tuple): + loss, details = loss + elif loss.ndim == 0: + details = {self.get_name(): float(loss)} + else: + details = {} + loss = loss * self._alpha + + if self._loss2: + loss2, details2 = self._loss2(*args, **kwargs) + loss = loss + loss2 + details |= details2 + + return loss, details + + +class Regr3D (Criterion, MultiLoss): + """ Ensure that all 3D points are correct. + Asymmetric loss: view1 is supposed to be the anchor. + + P1 = RT1 @ D1, point clouds at world coord_frame + P2 = RT2 @ D2 + loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1), I is identical matrix + loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2) + = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2) + """ + + def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False): + super().__init__(criterion) + self.norm_mode = norm_mode + self.gt_scale = gt_scale + + def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None): + # everything is normalized w.r.t. camera of view1 + in_camera1 = inv(gt1['camera_pose']) + gt_pts1 = geotrf(in_camera1, gt1['pts3d']) # B,H,W,3 + gt_pts2 = geotrf(in_camera1, gt2['pts3d']) # B,H,W,3 + + valid1 = gt1['valid_mask'].clone() + valid2 = gt2['valid_mask'].clone() + + if dist_clip is not None: + # points that are too far-away == invalid + dis1 = gt_pts1.norm(dim=-1) # (B, H, W) + dis2 = gt_pts2.norm(dim=-1) # (B, H, W) + valid1 = valid1 & (dis1 <= dist_clip) + valid2 = valid2 & (dis2 <= dist_clip) + + pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False) + pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True) + + # normalize 3d points + if self.norm_mode: + pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2) + if self.norm_mode and not self.gt_scale: + gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2) + + return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, {} + + def compute_loss(self, gt1, gt2, pred1, pred2, **kw): + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \ + self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw) + # loss on img1 side + l1 = self.criterion(pred_pts1[mask1], gt_pts1[mask1]) + # loss on gt2 side + l2 = self.criterion(pred_pts2[mask2], gt_pts2[mask2]) + self_name = type(self).__name__ + details = {self_name + '_pts3d_1': float(l1.mean()), self_name + '_pts3d_2': float(l2.mean())} + return Sum((l1, mask1), (l2, mask2)), (details | monitoring) + + +class ConfLoss (MultiLoss): + """ Weighted regression by learned confidence. + Assuming the input pixel_loss is a pixel-level regression loss. + + Principle: + high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10) + low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10) + + alpha: hyperparameter + """ + + def __init__(self, pixel_loss, alpha=1): + super().__init__() + assert alpha > 0 + self.alpha = alpha + self.pixel_loss = pixel_loss.with_reduction('none') + + def get_name(self): + return f'ConfLoss({self.pixel_loss})' + + def get_conf_log(self, x): + return x, torch.log(x) + + def compute_loss(self, gt1, gt2, pred1, pred2, **kw): + # compute per-pixel loss + ((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw) + if loss1.numel() == 0: + print('NO VALID POINTS in img1', force=True) + if loss2.numel() == 0: + print('NO VALID POINTS in img2', force=True) + + # weight by confidence + conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1]) + conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2]) + conf_loss1 = loss1 * conf1 - self.alpha * log_conf1 + conf_loss2 = loss2 * conf2 - self.alpha * log_conf2 + + # average + nan protection (in case of no valid pixels at all) + conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0 + conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0 + + return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details) + + +class Regr3D_ShiftInv (Regr3D): + """ Same than Regr3D but invariant to depth shift. + """ + + def get_all_pts3d(self, gt1, gt2, pred1, pred2): + # compute unnormalized points + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \ + super().get_all_pts3d(gt1, gt2, pred1, pred2) + + # compute median depth + gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2] + pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2] + gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None] + pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None] + + # subtract the median depth + gt_z1 -= gt_shift_z + gt_z2 -= gt_shift_z + pred_z1 -= pred_shift_z + pred_z2 -= pred_shift_z + + # monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach()) + return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring + + +class Regr3D_ScaleInv (Regr3D): + """ Same than Regr3D but invariant to depth shift. + if gt_scale == True: enforce the prediction to take the same scale than GT + """ + + def get_all_pts3d(self, gt1, gt2, pred1, pred2): + # compute depth-normalized points + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = super().get_all_pts3d(gt1, gt2, pred1, pred2) + + # measure scene scale + _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2) + _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2) + + # prevent predictions to be in a ridiculous range + pred_scale = pred_scale.clip(min=1e-3, max=1e3) + + # subtract the median depth + if self.gt_scale: + pred_pts1 *= gt_scale / pred_scale + pred_pts2 *= gt_scale / pred_scale + # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean()) + else: + gt_pts1 /= gt_scale + gt_pts2 /= gt_scale + pred_pts1 /= pred_scale + pred_pts2 /= pred_scale + # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()) + + return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring + + +class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv): + # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv + pass diff --git a/monst3r/dust3r/model.py b/monst3r/dust3r/model.py new file mode 100644 index 0000000..7ee4b7b --- /dev/null +++ b/monst3r/dust3r/model.py @@ -0,0 +1,214 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# DUSt3R model class +# -------------------------------------------------------- +from copy import deepcopy +import torch +import os +from packaging import version +import huggingface_hub + +from .utils.misc import fill_default_args, freeze_all_params, is_symmetrized, interleave, transpose_to_landscape +from .heads import head_factory +from dust3r.patch_embed import get_patch_embed, ManyAR_PatchEmbed +from third_party.raft import load_RAFT + +import dust3r.utils.path_to_croco # noqa: F401 +from models.croco import CroCoNet # noqa + +inf = float('inf') + +hf_version_number = huggingface_hub.__version__ +assert version.parse(hf_version_number) >= version.parse("0.22.0"), "Outdated huggingface_hub version, please reinstall requirements.txt" + +def load_model(model_path, device, verbose=True): + if verbose: + print('... loading model from', model_path) + ckpt = torch.load(model_path, map_location='cpu', weights_only=False) + args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") + if 'landscape_only' not in args: + args = args[:-1] + ', landscape_only=False)' + else: + args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False') + assert "landscape_only=False" in args + if verbose: + print(f"instantiating : {args}") + net = eval(args) + s = net.load_state_dict(ckpt['model'], strict=False) + if verbose: + print(s) + return net.to(device) + + +class AsymmetricCroCo3DStereo ( + CroCoNet, + huggingface_hub.PyTorchModelHubMixin, + library_name="dust3r", + repo_url="https://github.com/junyi/monst3r", + tags=["image-to-3d"], +): + """ Two siamese encoders, followed by two decoders. + The goal is to output 3d points directly, both images in view1's frame + (hence the asymmetry). + """ + + def __init__(self, + output_mode='pts3d', + head_type='linear', + depth_mode=('exp', -inf, inf), + conf_mode=('exp', 1, inf), + freeze='none', + landscape_only=True, + patch_embed_cls='PatchEmbedDust3R', # PatchEmbedDust3R or ManyAR_PatchEmbed + **croco_kwargs): + self.patch_embed_cls = patch_embed_cls + self.croco_args = fill_default_args(croco_kwargs, super().__init__) + super().__init__(**croco_kwargs) + + # dust3r specific initialization + self.dec_blocks2 = deepcopy(self.dec_blocks) + self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs) + self.set_freeze(freeze) + + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kw): + if os.path.isfile(pretrained_model_name_or_path): + return load_model(pretrained_model_name_or_path, device='cpu') + else: + return super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw) + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim) + + def load_state_dict(self, ckpt, **kw): + # duplicate all weights for the second decoder if not present + new_ckpt = dict(ckpt) + if not any(k.startswith('dec_blocks2') for k in ckpt): + for key, value in ckpt.items(): + if key.startswith('dec_blocks'): + new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value + return super().load_state_dict(new_ckpt, **kw) + + def set_freeze(self, freeze): # this is for use by downstream models + self.freeze = freeze + to_be_frozen = { + 'none': [], + 'mask': [self.mask_token], + 'encoder': [self.mask_token, self.patch_embed, self.enc_blocks], + 'encoder_and_decoder': [self.mask_token, self.patch_embed, self.enc_blocks, self.dec_blocks, self.dec_blocks2], + } + freeze_all_params(to_be_frozen[freeze]) + print(f'Freezing {freeze} parameters') + + def _set_prediction_head(self, *args, **kwargs): + """ No prediction head """ + return + + def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, + **kw): + if type(img_size) is int: + img_size = (img_size, img_size) + assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \ + f'{img_size=} must be multiple of {patch_size=}' + self.output_mode = output_mode + self.head_type = head_type + self.depth_mode = depth_mode + self.conf_mode = conf_mode + # allocate heads + self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) + self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) + # magic wrapper + self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only) + self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only) + + def _encode_image(self, image, true_shape): + # embed the image into patches (x has size B x Npatches x C) + x, pos = self.patch_embed(image, true_shape=true_shape) + # x (B, 576, 1024) pos (B, 576, 2); patch_size=16 + B,N,C = x.size() + posvis = pos + # add positional embedding without cls token + assert self.enc_pos_embed is None + # TODO: where to add mask for the patches + # now apply the transformer encoder and normalization + for blk in self.enc_blocks: + x = blk(x, posvis) + + x = self.enc_norm(x) + return x, pos, None + + def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2): + if img1.shape[-2:] == img2.shape[-2:]: + out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0), + torch.cat((true_shape1, true_shape2), dim=0)) + out, out2 = out.chunk(2, dim=0) + pos, pos2 = pos.chunk(2, dim=0) + else: + out, pos, _ = self._encode_image(img1, true_shape1) + out2, pos2, _ = self._encode_image(img2, true_shape2) + return out, out2, pos, pos2 + + def _encode_symmetrized(self, view1, view2): + img1 = view1['img'] + img2 = view2['img'] + B = img1.shape[0] + + # Recover true_shape when available, otherwise assume that the img shape is the true one + shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1)) + shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1)) + + # warning! maybe the images have different portrait/landscape orientations + if is_symmetrized(view1, view2): + # computing half of forward pass!' + feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2]) + feat1, feat2 = interleave(feat1, feat2) + pos1, pos2 = interleave(pos1, pos2) + else: + feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2) + + return (shape1, shape2), (feat1, feat2), (pos1, pos2) + + def _decoder(self, f1, pos1, f2, pos2): + final_output = [(f1, f2)] # before projection + original_D = f1.shape[-1] + + # project to decoder dim + f1 = self.decoder_embed(f1) + f2 = self.decoder_embed(f2) + + final_output.append((f1, f2)) + for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2): + # img1 side + f1, _ = blk1(*final_output[-1][::+1], pos1, pos2) + # img2 side + f2, _ = blk2(*final_output[-1][::-1], pos2, pos1) + # store the result + final_output.append((f1, f2)) + + # normalize last output + del final_output[1] # duplicate with final_output[0] + final_output[-1] = tuple(map(self.dec_norm, final_output[-1])) + return zip(*final_output) + + def _downstream_head(self, head_num, decout, img_shape): + B, S, D = decout[-1].shape + # img_shape = tuple(map(int, img_shape)) + head = getattr(self, f'head{head_num}') + return head(decout, img_shape) + + def forward(self, view1, view2): + # encode the two images --> B,S,D + (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2) + + # combine all ref images into object-centric representation + dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2) + + with torch.cuda.amp.autocast(enabled=False): + res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1) + res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2) + + res2['pts3d_in_other_view'] = res2.pop('pts3d') # predict view2's pts3d in view1's frame + return res1, res2 diff --git a/monst3r/dust3r/optim_factory.py b/monst3r/dust3r/optim_factory.py new file mode 100644 index 0000000..9b9c16e --- /dev/null +++ b/monst3r/dust3r/optim_factory.py @@ -0,0 +1,14 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# optimization functions +# -------------------------------------------------------- + + +def adjust_learning_rate_by_lr(optimizer, lr): + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr diff --git a/monst3r/dust3r/patch_embed.py b/monst3r/dust3r/patch_embed.py new file mode 100644 index 0000000..e8d6c5d --- /dev/null +++ b/monst3r/dust3r/patch_embed.py @@ -0,0 +1,70 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# PatchEmbed implementation for DUST3R, +# in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio +# -------------------------------------------------------- +import torch +import dust3r.utils.path_to_croco # noqa: F401 +from models.blocks import PatchEmbed # noqa + + +def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim): + assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed'] + patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim) + return patch_embed + + +class PatchEmbedDust3R(PatchEmbed): + def forward(self, x, **kw): + B, C, H, W = x.shape + assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, pos + + +class ManyAR_PatchEmbed (PatchEmbed): + """ Handle images with non-square aspect ratio. + All images in the same batch have the same aspect ratio. + true_shape = [(height, width) ...] indicates the actual shape of each image. + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, init='xavier'): + self.embed_dim = embed_dim + super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten, init) + + def forward(self, img, true_shape): + B, C, H, W = img.shape + assert W >= H, f'img should be in landscape mode, but got {W=} {H=}' + assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}" + + # size expressed in tokens + W //= self.patch_size[0] + H //= self.patch_size[1] + n_tokens = H * W + + height, width = true_shape.T + is_landscape = (width >= height) + is_portrait = ~is_landscape + + # allocate result + x = img.new_zeros((B, n_tokens, self.embed_dim)) + pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64) + + # linear projection, transposed if necessary + x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float() + x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float() + + pos[is_landscape] = self.position_getter(1, H, W, pos.device) + pos[is_portrait] = self.position_getter(1, W, H, pos.device) + + x = self.norm(x) + return x, pos diff --git a/monst3r/dust3r/pose_eval.py b/monst3r/dust3r/pose_eval.py new file mode 100644 index 0000000..8036f80 --- /dev/null +++ b/monst3r/dust3r/pose_eval.py @@ -0,0 +1,239 @@ +import os +import math +import cv2 +import numpy as np +import torch +from dust3r.utils.vo_eval import load_traj, eval_metrics, plot_trajectory, save_trajectory_tum_format, process_directory, calculate_averages +import croco.utils.misc as misc +import torch.distributed as dist +from tqdm import tqdm +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode +from dust3r.utils.image import load_images, rgb, enlarge_seg_masks +from dust3r.image_pairs import make_pairs +from dust3r.inference import inference +from dust3r.demo import get_3D_model_from_scene +import dust3r.eval_metadata +from dust3r.eval_metadata import dataset_metadata + +def eval_pose_estimation(args, model, device, save_dir=None): + metadata = dataset_metadata.get(args.eval_dataset, dataset_metadata['sintel']) + img_path = metadata['img_path'] + mask_path = metadata['mask_path'] + + ate_mean, rpe_trans_mean, rpe_rot_mean, outfile_list, bug = eval_pose_estimation_dist( + args, model, device, save_dir=save_dir, img_path=img_path, mask_path=mask_path + ) + return ate_mean, rpe_trans_mean, rpe_rot_mean, outfile_list, bug + +def eval_pose_estimation_dist(args, model, device, img_path, save_dir=None, mask_path=None): + + metadata = dataset_metadata.get(args.eval_dataset, dataset_metadata['sintel']) + anno_path = metadata.get('anno_path', None) + + silent = args.silent + seq_list = args.seq_list + if seq_list is None: + if metadata.get('full_seq', False): + args.full_seq = True + else: + seq_list = metadata.get('seq_list', []) + if args.full_seq: + seq_list = os.listdir(img_path) + seq_list = [seq for seq in seq_list if os.path.isdir(os.path.join(img_path, seq))] + seq_list = sorted(seq_list) + + if save_dir is None: + save_dir = args.output_dir + + # Split seq_list across processes + if misc.is_dist_avail_and_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + + total_seqs = len(seq_list) + seqs_per_proc = (total_seqs + world_size - 1) // world_size # Ceiling division + + start_idx = rank * seqs_per_proc + end_idx = min(start_idx + seqs_per_proc, total_seqs) + + seq_list = seq_list[start_idx:end_idx] + + ate_list = [] + rpe_trans_list = [] + rpe_rot_list = [] + outfile_list = [] + load_img_size = 512 + + error_log_path = f'{save_dir}/_error_log_{rank}.txt' # Unique log file per process + bug = False + + for seq in tqdm(seq_list): + try: + dir_path = metadata['dir_path_func'](img_path, seq) + + # Handle skip_condition + skip_condition = metadata.get('skip_condition', None) + if skip_condition is not None and skip_condition(save_dir, seq): + continue + + mask_path_seq_func = metadata.get('mask_path_seq_func', lambda mask_path, seq: None) + mask_path_seq = mask_path_seq_func(mask_path, seq) + + filelist = [os.path.join(dir_path, name) for name in os.listdir(dir_path)] + filelist.sort() + filelist = filelist[::args.pose_eval_stride] + max_winsize = max(1, math.ceil((len(filelist)-1)/2)) + scene_graph_type = args.scene_graph_type + if int(scene_graph_type.split('-')[1]) > max_winsize: + scene_graph_type = f'{args.scene_graph_type.split("-")[0]}-{max_winsize}' + if len(scene_graph_type.split("-")) > 2: + scene_graph_type += f'-{args.scene_graph_type.split("-")[2]}' + imgs = load_images( + filelist, size=load_img_size, verbose=False, + dynamic_mask_root=mask_path_seq, crop=not args.no_crop + ) + if args.eval_dataset == 'davis' and len(imgs) > 95: + # use swinstride-4 + scene_graph_type = scene_graph_type.replace('5', '4') + pairs = make_pairs( + imgs, scene_graph=scene_graph_type, prefilter=None, symmetrize=True + ) + + output = inference(pairs, model, device, batch_size=1, verbose=not silent) + + with torch.enable_grad(): + if len(imgs) > 2: + mode = GlobalAlignerMode.PointCloudOptimizer + scene = global_aligner( + output, device=device, mode=mode, verbose=not silent, + shared_focal=not args.not_shared_focal and not args.use_gt_focal, + flow_loss_weight=args.flow_loss_weight, flow_loss_fn=args.flow_loss_fn, + depth_regularize_weight=args.depth_regularize_weight, + num_total_iter=args.n_iter, temporal_smoothing_weight=args.temporal_smoothing_weight, motion_mask_thre=args.motion_mask_thre, + flow_loss_start_epoch=args.flow_loss_start_epoch, flow_loss_thre=args.flow_loss_thre, translation_weight=args.translation_weight, + sintel_ckpt=args.eval_dataset == 'sintel', use_self_mask=not args.use_gt_mask, sam2_mask_refine=args.sam2_mask_refine, + empty_cache=len(imgs) >= 80 and len(pairs) > 600, pxl_thre=args.pxl_thresh, # empty cache to make it run on 48GB GPU + ) + if args.use_gt_focal: + focal_path = os.path.join( + img_path.replace('final', 'camdata_left'), seq, 'focal.txt' + ) + focals = np.loadtxt(focal_path) + focals = focals[::args.pose_eval_stride] + original_img_size = cv2.imread(filelist[0]).shape[:2] + resized_img_size = tuple(imgs[0]['img'].shape[-2:]) + focals = focals * max( + (resized_img_size[0] / original_img_size[0]), + (resized_img_size[1] / original_img_size[1]) + ) + scene.preset_focal(focals, requires_grad=False) # TODO: requires_grad=False + lr = 0.01 + loss = scene.compute_global_alignment( + init='mst', niter=args.n_iter, schedule=args.pose_schedule, lr=lr, + ) + else: + mode = GlobalAlignerMode.PairViewer + scene = global_aligner(output, device=device, mode=mode, verbose=not silent) + + if args.save_pose_qualitative: + outfile = get_3D_model_from_scene( + outdir=save_dir, silent=silent, scene=scene, min_conf_thr=2, as_pointcloud=True, mask_sky=False, + clean_depth=True, transparent_cams=False, cam_size=0.01, save_name=seq + ) + else: + outfile = None + pred_traj = scene.get_tum_poses() + + os.makedirs(f'{save_dir}/{seq}', exist_ok=True) + scene.clean_pointcloud() + scene.save_tum_poses(f'{save_dir}/{seq}/pred_traj.txt') + scene.save_focals(f'{save_dir}/{seq}/pred_focal.txt') + scene.save_intrinsics(f'{save_dir}/{seq}/pred_intrinsics.txt') + scene.save_depth_maps(f'{save_dir}/{seq}') + scene.save_dynamic_masks(f'{save_dir}/{seq}') + scene.save_conf_maps(f'{save_dir}/{seq}') + scene.save_init_conf_maps(f'{save_dir}/{seq}') + scene.save_rgb_imgs(f'{save_dir}/{seq}') + enlarge_seg_masks(f'{save_dir}/{seq}', kernel_size=5 if args.use_gt_mask else 3) + + gt_traj_file = metadata['gt_traj_func'](img_path, anno_path, seq) + traj_format = metadata.get('traj_format', None) + + if args.eval_dataset == 'sintel': + gt_traj = load_traj(gt_traj_file=gt_traj_file, stride=args.pose_eval_stride) + elif traj_format is not None: + gt_traj = load_traj(gt_traj_file=gt_traj_file, traj_format=traj_format) + else: + gt_traj = None + + if gt_traj is not None: + ate, rpe_trans, rpe_rot = eval_metrics( + pred_traj, gt_traj, seq=seq, filename=f'{save_dir}/{seq}_eval_metric.txt' + ) + plot_trajectory( + pred_traj, gt_traj, title=seq, filename=f'{save_dir}/{seq}.png' + ) + else: + ate, rpe_trans, rpe_rot = 0, 0, 0 + outfile = None + bug = True + + ate_list.append(ate) + rpe_trans_list.append(rpe_trans) + rpe_rot_list.append(rpe_rot) + outfile_list.append(outfile) + + # Write to error log after each sequence + with open(error_log_path, 'a') as f: + f.write(f'{args.eval_dataset}-{seq: <16} | ATE: {ate:.5f}, RPE trans: {rpe_trans:.5f}, RPE rot: {rpe_rot:.5f}\n') + f.write(f'{ate:.5f}\n') + f.write(f'{rpe_trans:.5f}\n') + f.write(f'{rpe_rot:.5f}\n') + + except Exception as e: + if 'out of memory' in str(e): + # Handle OOM + torch.cuda.empty_cache() # Clear the CUDA memory + with open(error_log_path, 'a') as f: + f.write(f'OOM error in sequence {seq}, skipping this sequence.\n') + print(f'OOM error in sequence {seq}, skipping...') + elif 'Degenerate covariance rank' in str(e) or 'Eigenvalues did not converge' in str(e): + # Handle Degenerate covariance rank exception and Eigenvalues did not converge exception + with open(error_log_path, 'a') as f: + f.write(f'Exception in sequence {seq}: {str(e)}\n') + print(f'Traj evaluation error in sequence {seq}, skipping.') + else: + raise e # Rethrow if it's not an expected exception + + # Aggregate results across all processes + if misc.is_dist_avail_and_initialized(): + torch.distributed.barrier() + + bug_tensor = torch.tensor(int(bug), device=device) + + bug = bool(bug_tensor.item()) + + # Handle outfile_list + outfile_list_all = [None for _ in range(world_size)] + + outfile_list_combined = [] + for sublist in outfile_list_all: + if sublist is not None: + outfile_list_combined.extend(sublist) + + results = process_directory(save_dir) + avg_ate, avg_rpe_trans, avg_rpe_rot = calculate_averages(results) + + # Write the averages to the error log (only on the main process) + if rank == 0: + with open(f'{save_dir}/_error_log.txt', 'a') as f: + # Copy the error log from each process to the main error log + for i in range(world_size): + with open(f'{save_dir}/_error_log_{i}.txt', 'r') as f_sub: + f.write(f_sub.read()) + f.write(f'Average ATE: {avg_ate:.5f}, Average RPE trans: {avg_rpe_trans:.5f}, Average RPE rot: {avg_rpe_rot:.5f}\n') + + return avg_ate, avg_rpe_trans, avg_rpe_rot, outfile_list_combined, bug diff --git a/monst3r/dust3r/post_process.py b/monst3r/dust3r/post_process.py new file mode 100644 index 0000000..550a9b4 --- /dev/null +++ b/monst3r/dust3r/post_process.py @@ -0,0 +1,60 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilities for interpreting the DUST3R output +# -------------------------------------------------------- +import numpy as np +import torch +from dust3r.utils.geometry import xy_grid + + +def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0., max_focal=np.inf): + """ Reprojection method, for when the absolute depth is known: + 1) estimate the camera focal using a robust estimator + 2) reproject points onto true rays, minimizing a certain error + """ + B, H, W, THREE = pts3d.shape + assert THREE == 3 + + # centered pixel grid + pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2) # B,HW,2 + pts3d = pts3d.flatten(1, 2) # (B, HW, 3) + + if focal_mode == 'median': + with torch.no_grad(): + # direct estimation of focal + u, v = pixels.unbind(dim=-1) + x, y, z = pts3d.unbind(dim=-1) + fx_votes = (u * z) / x + fy_votes = (v * z) / y + + # assume square pixels, hence same focal for X and Y + f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1) + focal = torch.nanmedian(f_votes, dim=-1).values + + elif focal_mode == 'weiszfeld': + # init focal with l2 closed form + # we try to find focal = argmin Sum | pixel - focal * (x,y)/z| + xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) # homogeneous (x,y,1) + + dot_xy_px = (xy_over_z * pixels).sum(dim=-1) + dot_xy_xy = xy_over_z.square().sum(dim=-1) + + focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1) + + # iterative re-weighted least-squares + for iter in range(10): + # re-weighting by inverse of distance + dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1) + # print(dis.nanmean(-1)) + w = dis.clip(min=1e-8).reciprocal() + # update the scaling with the new weights + focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1) + else: + raise ValueError(f'bad {focal_mode=}') + + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515 + focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base) + # print(focal) + return focal diff --git a/monst3r/dust3r/training.py b/monst3r/dust3r/training.py new file mode 100644 index 0000000..686b710 --- /dev/null +++ b/monst3r/dust3r/training.py @@ -0,0 +1,521 @@ +# -------------------------------------------------------- +# training code for DUSt3R +# -------------------------------------------------------- +import os +os.environ['OMP_NUM_THREADS'] = '4' # will affect the performance of pairwise prediction +import argparse +import datetime +import json +import numpy as np +import sys +import time +import math +import wandb +from collections import defaultdict +from pathlib import Path +from typing import Sized + +import torch +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + +from dust3r.model import AsymmetricCroCo3DStereo, inf # noqa: F401, needed when loading the model +from dust3r.datasets import get_data_loader # noqa +from dust3r.losses import * # noqa: F401, needed when loading the model +from dust3r.inference import loss_of_one_batch, visualize_results # noqa + +from dust3r.pose_eval import eval_pose_estimation +from dust3r.depth_eval import eval_mono_depth_estimation + +# from demo import get_3D_model_from_scene +import dust3r.utils.path_to_croco # noqa: F401 +import croco.utils.misc as misc # noqa +from croco.utils.misc import NativeScalerWithGradNormCount as NativeScaler # noqa + + +def get_args_parser(): + parser = argparse.ArgumentParser('DUST3R training', add_help=False) + # model and criterion + parser.add_argument('--model', default="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', \ + img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), \ + enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, freeze='encoder')", + type=str, help="string containing the model to build") + parser.add_argument('--pretrained', default=None, help='path of a starting checkpoint') + parser.add_argument('--train_criterion', default="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)", + type=str, help="train criterion") + parser.add_argument('--test_criterion', default=None, type=str, help="test criterion") + + # dataset + parser.add_argument('--train_dataset', default='[None]', type=str, help="training set") + parser.add_argument('--test_dataset', default='[None]', type=str, help="testing set") + + # training + parser.add_argument('--seed', default=0, type=int, help="Random seed") + parser.add_argument('--batch_size', default=64, type=int, + help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus") + parser.add_argument('--accum_iter', default=1, type=int, + help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)") + parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler") + parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)") + parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)') + parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR', + help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') + parser.add_argument('--min_lr', type=float, default=0., metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0') + parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR') + parser.add_argument('--amp', type=int, default=0, + choices=[0, 1], help="Use Automatic Mixed Precision for pretraining") + parser.add_argument("--cudnn_benchmark", action='store_true', default=False, + help="set cudnn.benchmark = False") + parser.add_argument("--eval_only", action='store_true', default=False) + parser.add_argument("--fixed_eval_set", action='store_true', default=False) + parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint (default: none)') + + # others + parser.add_argument('--num_workers', default=8, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--eval_freq', type=int, default=1, help='Test loss evaluation frequency') + parser.add_argument('--save_freq', default=1, type=int, + help='frequence (number of epochs) to save checkpoint in checkpoint-last.pth') + parser.add_argument('--keep_freq', default=5, type=int, + help='frequence (number of epochs) to save checkpoint in checkpoint-%d.pth') + parser.add_argument('--print_freq', default=20, type=int, + help='frequence (number of iterations) to print infos while training') + parser.add_argument('--wandb', action='store_true', default=False, help='use wandb for logging') + parser.add_argument('--num_save_visual', default=1, type=int, help='number of visualizations to save') + + # switch mode for train / eval pose / eval depth + parser.add_argument('--mode', default='train', type=str, help='train / eval_pose / eval_depth') + + # for pose eval + parser.add_argument('--pose_eval_freq', default=0, type=int, help='pose evaluation frequency') + parser.add_argument('--pose_eval_stride', default=1, type=int, help='stride for pose evaluation') + parser.add_argument('--scene_graph_type', default='swinstride-5-noncyclic', type=str, help='scene graph window size') + parser.add_argument('--save_best_pose', action='store_true', default=False, help='save best pose') + parser.add_argument('--n_iter', default=300, type=int, help='number of iterations for pose optimization') + parser.add_argument('--save_pose_qualitative', action='store_true', default=False, help='save qualitative pose results') + parser.add_argument('--temporal_smoothing_weight', default=0.01, type=float, help='temporal smoothing weight for pose optimization') + parser.add_argument('--not_shared_focal', action='store_true', default=False, help='use shared focal length for pose optimization') + parser.add_argument('--use_gt_focal', action='store_true', default=False, help='use ground truth focal length for pose optimization') + parser.add_argument('--pose_schedule', default='linear', type=str, help='pose optimization schedule') + + parser.add_argument('--flow_loss_weight', default=0.01, type=float, help='flow loss weight for pose optimization') + parser.add_argument('--flow_loss_fn', default='smooth_l1', type=str, help='flow loss type for pose optimization') + parser.add_argument('--use_gt_mask', action='store_true', default=False, help='use gt mask for pose optimization, for sintel/davis') + parser.add_argument('--motion_mask_thre', default=0.35, type=float, help='motion mask threshold for pose optimization') + parser.add_argument('--sam2_mask_refine', action='store_true', default=False, help='use sam2 mask refine for the motion for pose optimization') + parser.add_argument('--flow_loss_start_epoch', default=0.1, type=float, help='start epoch for flow loss') + parser.add_argument('--flow_loss_thre', default=20, type=float, help='threshold for flow loss') + parser.add_argument('--pxl_thresh', default=50.0, type=float, help='threshold for flow loss') + parser.add_argument('--depth_regularize_weight', default=0.0, type=float, help='depth regularization weight for pose optimization') + parser.add_argument('--translation_weight', default=1, type=float, help='translation weight for pose optimization') + parser.add_argument('--silent', action='store_true', default=False, help='silent mode for pose evaluation') + parser.add_argument('--full_seq', action='store_true', default=False, help='use full sequence for pose evaluation') + parser.add_argument('--seq_list', nargs='+', default=None, help='list of sequences for pose evaluation') + + parser.add_argument('--eval_dataset', type=str, default='sintel', + choices=['davis', 'kitti', 'bonn', 'scannet', 'tum', 'nyu', 'sintel'], + help='choose dataset for pose evaluation') + + # for monocular depth eval + parser.add_argument('--no_crop', action='store_true', default=False, help='do not crop the image for monocular depth evaluation') + + # output dir + parser.add_argument('--output_dir', default='./results/tmp', type=str, help="path where to save the output") + return parser + +def load_model(args, device): + # model + print('Loading model: {:s}'.format(args.model)) + model = eval(args.model) + model.to(device) + model_without_ddp = model + if args.pretrained and not args.resume: + print('Loading pretrained: ', args.pretrained) + ckpt = torch.load(args.pretrained, map_location=device) + print(model.load_state_dict(ckpt['model'], strict=False)) + del ckpt # in case it occupies memory + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True) + model_without_ddp = model.module + + return model, model_without_ddp + +def train(args): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + world_size = misc.get_world_size() + # if main process, init wandb + if args.wandb and misc.is_main_process(): + wandb.init(name=args.output_dir.split('/')[-1], + project='dust3r', + config=args, + sync_tensorboard=True, + dir=args.output_dir) + + print("output_dir: " + args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + # auto resume if not specified + if args.resume is None: + last_ckpt_fname = os.path.join(args.output_dir, f'checkpoint-last.pth') + if os.path.isfile(last_ckpt_fname) and (not args.eval_only): args.resume = last_ckpt_fname + + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(', ', ',\n')) + + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # fix the seed + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = args.cudnn_benchmark + + model, model_without_ddp = load_model(args, device) + + # training dataset and loader + print('Building train dataset {:s}'.format(args.train_dataset)) + # dataset and loader + data_loader_train = build_dataset(args.train_dataset, args.batch_size, args.num_workers, test=False) + print('Building test dataset {:s}'.format(args.train_dataset)) + data_loader_test = {} + for dataset in args.test_dataset.split('+'): + testset = build_dataset(dataset, args.batch_size, args.num_workers, test=True) + name_testset = dataset.split('(')[0] + if getattr(testset.dataset.dataset, 'strides', None) is not None: + name_testset += f'_stride{testset.dataset.dataset.strides}' + data_loader_test[name_testset] = testset + + print(f'>> Creating train criterion = {args.train_criterion}') + train_criterion = eval(args.train_criterion).to(device) + print(f'>> Creating test criterion = {args.test_criterion or args.train_criterion}') + test_criterion = eval(args.test_criterion or args.criterion).to(device) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + if args.lr is None: # only base_lr is specified + args.lr = args.blr * eff_batch_size / 256 + print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) + print("actual lr: %.2e" % args.lr) + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + # following timm: set wd as 0 for bias and norm layers + param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) + optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) + # print(optimizer) + loss_scaler = NativeScaler() + + def write_log_stats(epoch, train_stats, test_stats): + if misc.is_main_process(): + if log_writer is not None: + log_writer.flush() + gathered_test_stats = {} + log_stats = dict(epoch=epoch, **{f'train_{k}': v for k, v in train_stats.items()}) + + for test_name, testset in data_loader_test.items(): + + if test_name not in test_stats: + continue + + if getattr(testset.dataset.dataset, 'strides', None) is not None: + original_test_name = test_name.split('_stride')[0] + if original_test_name not in gathered_test_stats.keys(): + gathered_test_stats[original_test_name] = [] + gathered_test_stats[original_test_name].append(test_stats[test_name]) + + log_stats.update({test_name + '_' + k: v for k, v in test_stats[test_name].items()}) + + if len(gathered_test_stats) > 0: + for original_test_name, stride_stats in gathered_test_stats.items(): + if len(stride_stats) > 1: + stride_stats = {k: np.mean([x[k] for x in stride_stats]) for k in stride_stats[0]} + log_stats.update({original_test_name + '_stride_mean_' + k: v for k, v in stride_stats.items()}) + if args.wandb: + log_dict = {original_test_name + '_stride_mean_' + k: v for k, v in stride_stats.items()} + log_dict.update({'epoch': epoch}) + wandb.log(log_dict) + + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + def save_model(epoch, fname, best_so_far, best_pose_ate_sofar=None): + misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch, fname=fname, best_so_far=best_so_far, best_pose_ate_sofar=best_pose_ate_sofar) + + best_so_far, best_pose_ate_sofar = misc.load_model(args=args, model_without_ddp=model_without_ddp, + optimizer=optimizer, loss_scaler=loss_scaler) + if best_so_far is None: + best_so_far = float('inf') + if best_pose_ate_sofar is None: + best_pose_ate_sofar = float('inf') + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir) + else: + log_writer = None + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + train_stats = test_stats = {} + for epoch in range(args.start_epoch, args.epochs + 1): + + # Test on multiple datasets + new_best = False + new_pose_best = False + already_saved = False + if (epoch > args.start_epoch and args.eval_freq > 0 and epoch % args.eval_freq == 0) or args.eval_only: + test_stats = {} + for test_name, testset in data_loader_test.items(): + stats = test_one_epoch(model, test_criterion, testset, + device, epoch, log_writer=log_writer, args=args, prefix=test_name) + test_stats[test_name] = stats + + # Save best of all + if stats['loss_med'] < best_so_far: + best_so_far = stats['loss_med'] + new_best = True + + # Ensure that eval_pose_estimation is only run on the main process + if args.pose_eval_freq>0 and (epoch % args.pose_eval_freq==0 or args.eval_only): + ate_mean, rpe_trans_mean, rpe_rot_mean, outfile_list, bug = eval_pose_estimation(args, model, device, save_dir=f'{args.output_dir}/{epoch}') + print(f'ATE mean: {ate_mean}, RPE trans mean: {rpe_trans_mean}, RPE rot mean: {rpe_rot_mean}') + + # Optionally log the results to wandb + if args.wandb and misc.is_main_process(): + wandb_dict = { + 'epoch': epoch, + 'ATE mean': ate_mean, + 'RPE trans mean': rpe_trans_mean, + 'RPE rot mean': rpe_rot_mean, + } + if args.save_pose_qualitative: + for outfile in outfile_list: + wandb_dict[outfile.split('/')[-1]] = wandb.Object3D(open(outfile)) + + wandb.log(wandb_dict) + + if ate_mean < best_pose_ate_sofar and not bug: # if the pose estimation is better, and w/o any error + best_pose_ate_sofar = ate_mean + new_pose_best = True + + # Synchronize all processes to ensure eval_pose_estimation is completed + try: + torch.distributed.barrier() + except: + pass + + # Save more stuff + write_log_stats(epoch, train_stats, test_stats) + + if args.eval_only and args.epochs <= 1: + exit(0) + + if epoch > args.start_epoch: + if args.keep_freq and epoch % args.keep_freq == 0: + save_model(epoch - 1, str(epoch), best_so_far, best_pose_ate_sofar) + already_saved = True + if new_best: + save_model(epoch - 1, 'best', best_so_far, best_pose_ate_sofar) + already_saved = True + if new_pose_best and args.save_best_pose: + save_model(epoch - 1, 'best_pose', best_so_far, best_pose_ate_sofar) + already_saved = True + + # Save immediately the last checkpoint + if epoch > args.start_epoch: + if args.save_freq and epoch % args.save_freq == 0 or epoch == args.epochs and not already_saved: + save_model(epoch - 1, 'last', best_so_far, best_pose_ate_sofar) + + if epoch >= args.epochs: + break # exit after writing last test to disk + + # Train + train_stats = train_one_epoch( + model, train_criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + log_writer=log_writer, + args=args) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + save_final_model(args, args.epochs, model_without_ddp, best_so_far=best_so_far) + + +def save_final_model(args, epoch, model_without_ddp, best_so_far=None): + output_dir = Path(args.output_dir) + checkpoint_path = output_dir / 'checkpoint-final.pth' + to_save = { + 'args': args, + 'model': model_without_ddp if isinstance(model_without_ddp, dict) else model_without_ddp.cpu().state_dict(), + 'epoch': epoch + } + if best_so_far is not None: + to_save['best_so_far'] = best_so_far + print(f'>> Saving model to {checkpoint_path} ...') + misc.save_on_master(to_save, checkpoint_path) + + +def build_dataset(dataset, batch_size, num_workers, test=False): + split = ['Train', 'Test'][test] + print(f'Building {split} Data loader for dataset: ', dataset) + loader = get_data_loader(dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_mem=True, + shuffle=not (test), + drop_last=not (test)) + + print(f"{split} dataset length: ", len(loader)) + return loader + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Sized, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, + args, + log_writer=None): + assert torch.backends.cuda.matmul.allow_tf32 == True + + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + accum_iter = args.accum_iter + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'): + data_loader.dataset.set_epoch(epoch) + if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'): + data_loader.sampler.set_epoch(epoch) + + optimizer.zero_grad() + + for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): + epoch_f = epoch + data_iter_step / len(data_loader) + + # we use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate(optimizer, epoch_f, args) + + batch_result = loss_of_one_batch(batch, model, criterion, device, + symmetrize_batch=True, + use_amp=bool(args.amp)) + loss, loss_details = batch_result['loss'] # criterion returns two values + loss_value = float(loss) + + if (data_iter_step % max((len(data_loader) // args.num_save_visual), 1) == 0) and misc.is_main_process() : + save_dir = f'{args.output_dir}/{epoch}' + Path(save_dir).mkdir(parents=True, exist_ok=True) + view1, view2, pred1, pred2 = batch_result['view1'], batch_result['view2'], batch_result['pred1'], batch_result['pred2'] + gt_visual = visualize_results(view1, view2, pred1, pred2, save_dir=save_dir, visualize_type='gt') + pred_visual = visualize_results(view1, view2, pred1, pred2, save_dir=save_dir, visualize_type='pred') + + if args.wandb: + wandb.log({ + 'epoch': epoch, + 'train_visual_gt': wandb.Object3D(open(gt_visual)), + 'train_visual_pred': wandb.Object3D(open(pred_visual)) + }) + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value), force=True) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + del loss + del batch + + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(epoch=epoch_f) + metric_logger.update(lr=lr) + metric_logger.update(loss=loss_value, **loss_details) + + if (data_iter_step + 1) % accum_iter == 0 and ((data_iter_step + 1) % (accum_iter * args.print_freq)) == 0: + loss_value_reduce = misc.all_reduce_mean(loss_value) # MUST BE EXECUTED BY ALL NODES + if log_writer is None: + continue + """ We use epoch_1000x as the x-axis in tensorboard. + This calibrates different curves when batch size changes. + """ + epoch_1000x = int(epoch_f * 1000) + log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) + log_writer.add_scalar('train_lr', lr, epoch_1000x) + log_writer.add_scalar('train_iter', epoch_1000x, epoch_1000x) + for name, val in loss_details.items(): + log_writer.add_scalar('train_' + name, val, epoch_1000x) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def test_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Sized, device: torch.device, epoch: int, + args, log_writer=None, prefix='test'): + + model.eval() + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.meters = defaultdict(lambda: misc.SmoothedValue(window_size=9**9)) + header = 'Test Epoch: [{}]'.format(epoch) + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'): + data_loader.dataset.set_epoch(epoch) if not args.fixed_eval_set else data_loader.dataset.set_epoch(0) + if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'): + data_loader.sampler.set_epoch(epoch) if not args.fixed_eval_set else data_loader.sampler.set_epoch(0) + + for idx, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): + + batch_result = loss_of_one_batch(batch, model, criterion, device, + symmetrize_batch=True, + use_amp=bool(args.amp)) + loss_tuple = batch_result['loss'] + loss_value, loss_details = loss_tuple # criterion returns two values + metric_logger.update(loss=float(loss_value), **loss_details) + + if args.num_save_visual>0 and (idx % max((len(data_loader) // args.num_save_visual), 1) == 0) and misc.is_main_process() : # save visualizations + save_dir = f'{args.output_dir}/{epoch}' + Path(save_dir).mkdir(parents=True, exist_ok=True) + view1, view2, pred1, pred2 = batch_result['view1'], batch_result['view2'], batch_result['pred1'], batch_result['pred2'] + gt_visual = visualize_results(view1, view2, pred1, pred2, save_dir=save_dir, visualize_type='gt') + pred_visual = visualize_results(view1, view2, pred1, pred2, save_dir=save_dir, visualize_type='pred') + + if args.wandb: + wandb.log({ + 'epoch': epoch, + 'test_visual_gt': wandb.Object3D(open(gt_visual)), + 'test_visual_pred': wandb.Object3D(open(pred_visual)) + }) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + + aggs = [('avg', 'global_avg'), ('med', 'median')] + results = {f'{k}_{tag}': getattr(meter, attr) for k, meter in metric_logger.meters.items() for tag, attr in aggs} + + if log_writer is not None: + for name, val in results.items(): + log_writer.add_scalar(prefix + '_' + name, val, 1000 * epoch) + + return results diff --git a/monst3r/dust3r/utils/__init__.py b/monst3r/dust3r/utils/__init__.py new file mode 100644 index 0000000..a326921 --- /dev/null +++ b/monst3r/dust3r/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/monst3r/dust3r/utils/device.py b/monst3r/dust3r/utils/device.py new file mode 100644 index 0000000..e3b6a74 --- /dev/null +++ b/monst3r/dust3r/utils/device.py @@ -0,0 +1,76 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for DUSt3R +# -------------------------------------------------------- +import numpy as np +import torch + + +def todevice(batch, device, callback=None, non_blocking=False): + ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). + + batch: list, tuple, dict of tensors or other things + device: pytorch device or 'numpy' + callback: function that would be called on every sub-elements. + ''' + if callback: + batch = callback(batch) + + if isinstance(batch, dict): + return {k: todevice(v, device) for k, v in batch.items()} + + if isinstance(batch, (tuple, list)): + return type(batch)(todevice(x, device) for x in batch) + + x = batch + if device == 'numpy': + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + elif x is not None: + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if torch.is_tensor(x): + x = x.to(device, non_blocking=non_blocking) + return x + + +to_device = todevice # alias + + +def to_numpy(x): return todevice(x, 'numpy') +def to_cpu(x): return todevice(x, 'cpu') +def to_cuda(x): return todevice(x, 'cuda') + + +def collate_with_cat(whatever, lists=False): + if isinstance(whatever, dict): + return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()} + + elif isinstance(whatever, (tuple, list)): + if len(whatever) == 0: + return whatever + elem = whatever[0] + T = type(whatever) + + if elem is None: + return None + if isinstance(elem, (bool, float, int, str)): + return whatever + if isinstance(elem, tuple): + return T(collate_with_cat(x, lists=lists) for x in zip(*whatever)) + if isinstance(elem, dict): + return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem} + + if isinstance(elem, torch.Tensor): + return listify(whatever) if lists else torch.cat(whatever) + if isinstance(elem, np.ndarray): + return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever]) + + # otherwise, we just chain lists + return sum(whatever, T()) + + +def listify(elems): + return [x for e in elems for x in e] diff --git a/monst3r/dust3r/utils/flow_vis.py b/monst3r/dust3r/utils/flow_vis.py new file mode 100644 index 0000000..e3d97f2 --- /dev/null +++ b/monst3r/dust3r/utils/flow_vis.py @@ -0,0 +1,184 @@ +import numpy as np +import cv2 +from functools import wraps +from matplotlib import pyplot as plt +import torch + +MAX_VALUES_BY_DTYPE = { + np.dtype('uint8'): 255, + np.dtype('uint16'): 65535, + np.dtype('uint32'): 4294967295, + np.dtype('float32'): 1.0, +} + +UNKNOWN_FLOW_THRESH = 1e7 +SMALLFLOW = 0.0 +LARGEFLOW = 1e8 + +def flow2rgb(flow_map, max_value): + if isinstance(flow_map,np.ndarray): + if flow_map.shape[2] == 2: + flow_map = flow_map.transpose(2,0, 1) + flow_map_np = flow_map + else: + if flow_map.shape[2] == 2: + # shape is HxWx2 + flow_map = flow_map.permute(2, 0, 1) + flow_map_np = flow_map.detach().cpu().numpy() + _, h, w = flow_map_np.shape + flow_map_np[:,(flow_map_np[0] == 0) & (flow_map_np[1] == 0)] = float('nan') + rgb_map = np.ones((3,h,w)).astype(np.float32) + if max_value is not None: + normalized_flow_map = flow_map_np / max_value + else: + normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max()) + rgb_map[0] += normalized_flow_map[0] + rgb_map[1] -= 0.5*(normalized_flow_map[0] + normalized_flow_map[1]) + rgb_map[2] += normalized_flow_map[1] + return rgb_map.clip(0,1) + + +def flow_to_image(flow, maxrad=None): + """ + Convert flow into middlebury color code image + :param flow: optical flow map + :return: optical flow image in middlebury color + """ + h,w, _ = flow.shape + u = flow[:, :, 0] + v = flow[:, :, 1] + + maxu = -999. + maxv = -999. + minu = 999. + minv = 999. + + idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) + u[idxUnknow] = 0 + v[idxUnknow] = 0 + + if maxrad is None: + rad = np.sqrt(u ** 2 + v ** 2) + maxrad = max(-1, np.max(rad)) + + #print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. %.3f" % (maxrad, minu,maxu, minv, maxv)) + + u = u/(maxrad + np.finfo(float).eps) + v = v/(maxrad + np.finfo(float).eps) + + img = compute_color(u, v) + + idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) + img[idx] = 0 + valid = np.ones((h,w), np.uint8) + valid[np.logical_and(u == 0 , v == 0)] = 0 + return np.uint8(img)*np.expand_dims(valid, axis=2) + + +def show_flow(flow): + """ + visualize optical flow map using matplotlib + :param filename: optical flow file + :return: None + """ + img = flow_to_image(flow) + plt.imshow(img) + plt.show() + + return img + + +def compute_color(u, v): + """ + compute optical flow color map + :param u: optical flow horizontal map + :param v: optical flow vertical map + :return: optical flow in color code + """ + [h, w] = u.shape + img = np.zeros([h, w, 3]) + nanIdx = np.isnan(u) | np.isnan(v) + u[nanIdx] = 0 + v[nanIdx] = 0 + + colorwheel = make_color_wheel() + ncols = np.size(colorwheel, 0) + + rad = np.sqrt(u**2+v**2) + + a = np.arctan2(-v, -u) / np.pi + + fk = (a+1) / 2 * (ncols - 1) + 1 + + k0 = np.floor(fk).astype(int) + + k1 = k0 + 1 + k1[k1 == ncols+1] = 1 + f = fk - k0 + + for i in range(0, np.size(colorwheel,1)): + tmp = colorwheel[:, i] + col0 = tmp[k0-1] / 255 + col1 = tmp[k1-1] / 255 + col = (1-f) * col0 + f * col1 + + idx = rad <= 1 + col[idx] = 1-rad[idx]*(1-col[idx]) + notidx = np.logical_not(idx) + + col[notidx] *= 0.75 + img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) + + return img + + +def make_color_wheel(): + """ + Generate color wheel according Middlebury color code + :return: Color wheel + """ + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + + colorwheel = np.zeros([ncols, 3]) + + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) + col += RY + + # YG + colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) + colorwheel[col:col+YG, 1] = 255 + col += YG + + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) + col += GC + + # CB + colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) + colorwheel[col:col+CB, 2] = 255 + col += CB + + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) + col += + BM + + # MR + colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) + colorwheel[col:col+MR, 0] = 255 + + return colorwheel + + diff --git a/monst3r/dust3r/utils/geometry.py b/monst3r/dust3r/utils/geometry.py new file mode 100644 index 0000000..7224728 --- /dev/null +++ b/monst3r/dust3r/utils/geometry.py @@ -0,0 +1,370 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# geometry utilitary functions +# -------------------------------------------------------- +import torch +import numpy as np +from scipy.spatial import cKDTree as KDTree + +from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans +from dust3r.utils.device import to_numpy + + +def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw): + """ Output a (H,W,2) array of int32 + with output[j,i,0] = i + origin[0] + output[j,i,1] = j + origin[1] + """ + if device is None: + # numpy + arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones + else: + # torch + arange = lambda *a, **kw: torch.arange(*a, device=device, **kw) + meshgrid, stack = torch.meshgrid, torch.stack + ones = lambda *a: torch.ones(*a, device=device) + + tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)] + grid = meshgrid(tw, th, indexing='xy') + if homogeneous: + grid = grid + (ones((H, W)),) + if unsqueeze is not None: + grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze)) + if cat_dim is not None: + grid = stack(grid, cat_dim) + return grid + + +def geotrf(Trf, pts, ncol=None, norm=False): + """ Apply a geometric transformation to a list of 3-D points. + + H: 3x3 or 4x4 projection matrix (typically a Homography) + p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) + + ncol: int. number of columns of the result (2 or 3) + norm: float. if != 0, the resut is projected on the z=norm plane. + + Returns an array of projected 2d points. + """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + # adapt shape if necessary + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + # optimized code + if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and + Trf.ndim == 3 and pts.ndim == 4): + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d + 1: + pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d] + else: + raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}') + else: + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match' + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res + + +def inv(mat): + """ Invert a torch or numpy matrix + """ + if isinstance(mat, torch.Tensor): + return torch.linalg.inv(mat) + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f'bad matrix type = {type(mat)}') + + +def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_): + """ + Args: + - depthmap (BxHxW array): + - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W] + Returns: + pointmap of absolute coordinates (BxHxWx3 array) + """ + + if len(depth.shape) == 4: + B, H, W, n = depth.shape + else: + B, H, W = depth.shape + n = None + + if len(pseudo_focal.shape) == 3: # [B,H,W] + pseudo_focalx = pseudo_focaly = pseudo_focal + elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W] + pseudo_focalx = pseudo_focal[:, 0] + if pseudo_focal.shape[1] == 2: + pseudo_focaly = pseudo_focal[:, 1] + else: + pseudo_focaly = pseudo_focalx + else: + raise NotImplementedError("Error, unknown input focal shape format.") + + assert pseudo_focalx.shape == depth.shape[:3] + assert pseudo_focaly.shape == depth.shape[:3] + grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None] + + # set principal point + if pp is None: + grid_x = grid_x - (W - 1) / 2 + grid_y = grid_y - (H - 1) / 2 + else: + grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None] + grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None] + + if n is None: + pts3d = torch.empty((B, H, W, 3), device=depth.device) + pts3d[..., 0] = depth * grid_x / pseudo_focalx + pts3d[..., 1] = depth * grid_y / pseudo_focaly + pts3d[..., 2] = depth + else: + pts3d = torch.empty((B, H, W, 3, n), device=depth.device) + pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None] + pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None] + pts3d[..., 2, :] = depth + return pts3d + + +def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + u, v = np.meshgrid(np.arange(W), np.arange(H)) + z_cam = depthmap + x_cam = (u - cu) * z_cam / fu + y_cam = (v - cv) * z_cam / fv + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + # Mask for valid coordinates + valid_mask = (depthmap > 0.0) + # Invalid any depth > 80m + valid_mask = valid_mask + return X_cam, valid_mask + + +def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, z_far=0, **kw): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.""" + X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics) + if z_far > 0: + valid_mask = valid_mask & (depthmap < z_far) + + X_world = X_cam # default + if camera_pose is not None: + # R_cam2world = np.float32(camera_params["R_cam2world"]) + # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze() + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + # Express in absolute coordinates (invalid depth values) + X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] + + return X_world, valid_mask + + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + return K + + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + return K + + +def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None, ret_factor=False): + """ renorm pointmaps pts1, pts2 with norm_mode + """ + assert pts1.ndim >= 3 and pts1.shape[-1] == 3 + assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3) + norm_mode, dis_mode = norm_mode.split('_') + + if norm_mode == 'avg': + # gather all points together (joint normalization) + nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3) + nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0) + all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + if dis_mode == 'dis': + pass # do nothing + elif dis_mode == 'log1p': + all_dis = torch.log1p(all_dis) + elif dis_mode == 'warp-log1p': + # actually warp input points before normalizing them + log_dis = torch.log1p(all_dis) + warp_factor = log_dis / all_dis.clip(min=1e-8) + H1, W1 = pts1.shape[1:-1] + pts1 = pts1 * warp_factor[:, :W1 * H1].view(-1, H1, W1, 1) + if pts2 is not None: + H2, W2 = pts2.shape[1:-1] + pts2 = pts2 * warp_factor[:, W1 * H1:].view(-1, H2, W2, 1) + all_dis = log_dis # this is their true distance afterwards + else: + raise ValueError(f'bad {dis_mode=}') + + norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8) + else: + # gather all points together (joint normalization) + nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3) + nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None + all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + + if norm_mode == 'avg': + norm_factor = all_dis.nanmean(dim=1) + elif norm_mode == 'median': + norm_factor = all_dis.nanmedian(dim=1).values.detach() + elif norm_mode == 'sqrt': + norm_factor = all_dis.sqrt().nanmean(dim=1)**2 + else: + raise ValueError(f'bad {norm_mode=}') + + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts1.ndim: + norm_factor.unsqueeze_(-1) + + res = pts1 / norm_factor + if pts2 is not None: + res = (res, pts2 / norm_factor) + if ret_factor: + res = res + (norm_factor,) + return res + + +@torch.no_grad() +def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5): + # set invalid points to NaN + _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1) + _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None + _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1 + + # compute median depth overall (ignoring nans) + if quantile == 0.5: + shift_z = torch.nanmedian(_z, dim=-1).values + else: + shift_z = torch.nanquantile(_z, quantile, dim=-1) + return shift_z # (B,) + + +@torch.no_grad() +def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True): + # set invalid points to NaN + _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3) + _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None + _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1 + + # compute median center + _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3) + if z_only: + _center[..., :2] = 0 # do not center X and Y + + # compute median norm + _norm = ((_pts - _center) if center else _pts).norm(dim=-1) + scale = torch.nanmedian(_norm, dim=1).values + return _center[:, None, :, :], scale[:, None, None, None] + + +def find_reciprocal_matches(P1, P2): + """ + returns 3 values: + 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match + 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1 + 3 - reciprocal_in_P2.sum(): the number of matches + """ + tree1 = KDTree(P1) + tree2 = KDTree(P2) + + _, nn1_in_P2 = tree2.query(P1, workers=8) + _, nn2_in_P1 = tree1.query(P2, workers=8) + + reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2))) + reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1))) + assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum() + return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum() + + +def get_med_dist_between_poses(poses): + from scipy.spatial.distance import pdist + return np.median(pdist([to_numpy(p[:3, 3]) for p in poses])) diff --git a/monst3r/dust3r/utils/goem_opt.py b/monst3r/dust3r/utils/goem_opt.py new file mode 100644 index 0000000..6d1827e --- /dev/null +++ b/monst3r/dust3r/utils/goem_opt.py @@ -0,0 +1,619 @@ +from matplotlib.pyplot import grid +import torch +from torch import nn +from torch.nn import functional as F +import math +from scipy.spatial.transform import Rotation + +def tum_to_pose_matrix(pose): + # pose: [tx, ty, tz, qw, qx, qy, qz] + assert pose.shape == (7,) + pose_xyzw = pose[[3, 4, 5, 6]] + r = Rotation.from_quat(pose_xyzw) + return np.vstack([np.hstack([r.as_matrix(), pose[:3].reshape(-1, 1)]), [0, 0, 0, 1]]) + +def depth_regularization_si_weighted(depth_pred, depth_init, pixel_wise_weight=None, pixel_wise_weight_scale=1, pixel_wise_weight_bias=1, eps=1e-6, pixel_weight_normalize=False): + # scale compute: + depth_pred = torch.clamp(depth_pred, min=eps) + depth_init = torch.clamp(depth_init, min=eps) + log_d_pred = torch.log(depth_pred) + log_d_init = torch.log(depth_init) + B, _, H, W = depth_pred.shape + scale = torch.sum(log_d_init - log_d_pred, + dim=[1, 2, 3], keepdim=True)/(H*W) + if pixel_wise_weight is not None: + if pixel_weight_normalize: + norm = torch.max(pixel_wise_weight.detach().view( + B, -1), dim=1, keepdim=False)[0] + pixel_wise_weight = pixel_wise_weight / \ + (norm[:, None, None, None]+eps) + pixel_wise_weight = pixel_wise_weight * \ + pixel_wise_weight_scale + pixel_wise_weight_bias + else: + pixel_wise_weight = 1 + si_loss = torch.sum(pixel_wise_weight*(log_d_pred - + log_d_init + scale)**2, dim=[1, 2, 3])/(H*W) + return si_loss.mean() + +class WarpImage(torch.nn.Module): + def __init__(self): + super(WarpImage, self).__init__() + self.base_coord = None + + def init_grid(self, shape, device): + H, W = shape + hh, ww = torch.meshgrid(torch.arange( + H).float(), torch.arange(W).float()) + coord = torch.zeros([1, H, W, 2]) + coord[0, ..., 0] = ww + coord[0, ..., 1] = hh + self.base_coord = coord.to(device) + self.W = W + self.H = H + + def warp_image(self, base_coord, img_1, flow_2_1): + B, C, H, W = flow_2_1.shape + sample_grids = base_coord + flow_2_1.permute([0, 2, 3, 1]) + sample_grids[..., 0] /= (W - 1) / 2 + sample_grids[..., 1] /= (H - 1) / 2 + sample_grids -= 1 + warped_image_2_from_1 = F.grid_sample( + img_1, sample_grids, align_corners=True) + return warped_image_2_from_1 + + def forward(self, img_1, flow_2_1): + B, _, H, W = flow_2_1.shape + if self.base_coord is None: + self.init_grid([H, W], device=flow_2_1.device) + base_coord = self.base_coord.expand([B, -1, -1, -1]) + return self.warp_image(base_coord, img_1, flow_2_1) + + +class CameraIntrinsics(nn.Module): + def __init__(self, init_focal_length=0.45, pixel_size=None): + super().__init__() + self.focal_length = nn.Parameter(torch.tensor(init_focal_length)) + self.pixel_size_buffer = pixel_size + + def register_shape(self, orig_shape, opt_shape) -> None: + self.orig_shape = orig_shape + self.opt_shape = opt_shape + H_orig, W_orig = orig_shape + H_opt, W_opt = opt_shape + if self.pixel_size_buffer is None: + # initialize as 35mm film + pixel_size = 0.433 / (H_orig ** 2 + W_orig ** 2) ** 0.5 + else: + pixel_size = self.pixel_size_buffer + self.register_buffer("pixel_size", torch.tensor(pixel_size)) + intrinsics_mat_buffer = torch.zeros(3, 3) + intrinsics_mat_buffer[0, -1] = (W_opt - 1) / 2 + intrinsics_mat_buffer[1, -1] = (H_opt - 1) / 2 + intrinsics_mat_buffer[2, -1] = 1 + self.register_buffer("intrinsics_mat", intrinsics_mat_buffer) + self.register_buffer("scale_H", torch.tensor( + H_opt / (H_orig * pixel_size))) + self.register_buffer("scale_W", torch.tensor( + W_opt / (W_orig * pixel_size))) + + def get_K_and_inv(self, with_batch_dim=True) -> torch.Tensor: + intrinsics_mat = self.intrinsics_mat.clone() + intrinsics_mat[0, 0] = self.focal_length * self.scale_W + intrinsics_mat[1, 1] = self.focal_length * self.scale_H + inv_intrinsics_mat = torch.linalg.inv(intrinsics_mat) + if with_batch_dim: + return intrinsics_mat[None, ...], inv_intrinsics_mat[None, ...] + else: + return intrinsics_mat, inv_intrinsics_mat + + +@torch.jit.script +def hat(v: torch.Tensor) -> torch.Tensor: + """ + Compute the Hat operator [1] of a batch of 3D vectors. + + Args: + v: Batch of vectors of shape `(minibatch , 3)`. + + Returns: + Batch of skew-symmetric matrices of shape + `(minibatch, 3 , 3)` where each matrix is of the form: + `[ 0 -v_z v_y ] + [ v_z 0 -v_x ] + [ -v_y v_x 0 ]` + + Raises: + ValueError if `v` is of incorrect shape. + + [1] https://en.wikipedia.org/wiki/Hat_operator + """ + + N, dim = v.shape + # if dim != 3: + # raise ValueError("Input vectors have to be 3-dimensional.") + + h = torch.zeros((N, 3, 3), dtype=v.dtype, device=v.device) + + x, y, z = v.unbind(1) + + h[:, 0, 1] = -z + h[:, 0, 2] = y + h[:, 1, 0] = z + h[:, 1, 2] = -x + h[:, 2, 0] = -y + h[:, 2, 1] = x + + return h + + +@torch.jit.script +def get_relative_transform(src_R, src_t, tgt_R, tgt_t): + tgt_R_inv = tgt_R.permute([0, 2, 1]) + relative_R = torch.matmul(tgt_R_inv, src_R) + relative_t = torch.matmul(tgt_R_inv, src_t - tgt_t) + return relative_R, relative_t + + +def reproject_depth(src_R, src_t, src_disp, tgt_R, tgt_t, tgt_disp, K_src, K_inv_src, K_trg, K_inv_trg, coord, eps=1e-6): + """ + Convert the depth map's value to another camera pose. + input: + src_R: rotation matrix of source camera + src_t: translation vector of source camera + tgt_R: rotation matrix of target camera + tgt_t: translation vector of target camera + K: intrinsics matrix of the camera + src_disp: disparity map of source camera + tgt_disp: disparity map of target camera + coord: coordinate grids + K_inv: inverse intrinsics matrix of the camera + output: + tgt_depth_from_src: source depth map reprojected to target camera, values are ready for warping. + src_depth_from_tgt: target depth map reprojected to source camera, values are ready for warping. + """ + B, _, H, W = src_disp.shape + + src_depth = 1/(src_disp + eps) + tgt_depth = 1/(tgt_disp + eps) + # project 1 to 2 + + src_depth_flat = src_depth.view([B, 1, H*W]) + src_xyz = src_depth_flat * src_R.matmul(K_inv_src.matmul(coord)) + src_t + src_xyz_at_tgt_cam = K_trg.matmul( + tgt_R.transpose(1, 2).matmul(src_xyz - tgt_t)) + tgt_depth_from_src = src_xyz_at_tgt_cam[:, 2, :].view([B, 1, H, W]) + # project 2 to 1 + tgt_depth_flat = tgt_depth.view([B, 1, H*W]) + tgt_xyz = tgt_depth_flat * tgt_R.matmul(K_inv_trg.matmul(coord)) + tgt_t + tgt_xyz_at_src_cam = K_src.matmul( + src_R.transpose(1, 2).matmul(tgt_xyz - src_t)) + src_depth_from_tgt = tgt_xyz_at_src_cam[:, 2, :].view([B, 1, H, W]) + return tgt_depth_from_src, src_depth_from_tgt + + +# @torch.jit.script +def warp_by_disp(src_R, src_t, tgt_R, tgt_t, K, src_disp, coord, inv_K, debug_mode=False, use_depth=False): + + if debug_mode: + B, C, H, W = src_disp.shape + relative_R, relative_t = get_relative_transform( + src_R, src_t, tgt_R, tgt_t) + + print(relative_t.shape) + H_mat = K.matmul(relative_R.matmul(inv_K)) # Nx3x3 + flat_disp = src_disp.view([B, 1, H * W]) # Nx1xNpoints + relative_t_flat = relative_t.expand([-1, -1, H*W]) + rot_coord = torch.matmul(H_mat, coord) + tr_coord = flat_disp * \ + torch.matmul(K, relative_t_flat) + tgt_coord = rot_coord + tr_coord + normalization_factor = (tgt_coord[:, 2:, :] + 1e-6) + rot_coord_normalized = rot_coord / normalization_factor + tr_coord_normalized = tr_coord / normalization_factor + tgt_coord_normalized = rot_coord_normalized + tr_coord_normalized + debug_info = {} + debug_info['tr_coord_normalized'] = tr_coord_normalized + debug_info['rot_coord_normalized'] = rot_coord_normalized + debug_info['tgt_coord_normalized'] = tgt_coord_normalized + debug_info['tr_coord'] = tr_coord + debug_info['rot_coord'] = rot_coord + debug_info['normalization_factor'] = normalization_factor + debug_info['relative_t_flat'] = relative_t_flat + return (tgt_coord_normalized - coord).view([B, 3, H, W]), debug_info + else: + B, C, H, W = src_disp.shape + relative_R, relative_t = get_relative_transform( + src_R, src_t, tgt_R, tgt_t) + H_mat = K.matmul(relative_R.matmul(inv_K)) # Nx3x3 + flat_disp = src_disp.view([B, 1, H * W]) # Nx1xNpoints + if use_depth: + tgt_coord = flat_disp * torch.matmul(H_mat, coord) + \ + torch.matmul(K, relative_t) + else: + tgt_coord = torch.matmul(H_mat, coord) + flat_disp * \ + torch.matmul(K, relative_t) + tgt_coord = tgt_coord / (tgt_coord[:, -1:, :] + 1e-6) + return (tgt_coord - coord).view([B, 3, H, W]), tgt_coord + + +def unproject_depth(depth, K_inv, R, t, coord): + # this need verification + B, _, H, W = depth.shape + disp_flat = depth.view([B, 1, H * W]) + xyz = disp_flat * R.matmul(K_inv.matmul(coord)) + t + return xyz.reshape([B, 3, H, W]) + + +@torch.jit.script +def _so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001): + """ + A helper function that computes the so3 exponential map and, + apart from the rotation matrix, also returns intermediate variables + that can be re-used in other functions. + """ + _, dim = log_rot.shape + # if dim != 3: + # raise ValueError("Input tensor shape has to be Nx3.") + + nrms = (log_rot * log_rot).sum(1) + # phis ... rotation angles + rot_angles = torch.clamp(nrms, eps).sqrt() + rot_angles_inv = 1.0 / rot_angles + fac1 = rot_angles_inv * rot_angles.sin() + fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos()) + skews = hat(log_rot) + skews_square = torch.bmm(skews, skews) + + R = ( + # pyre-fixme[16]: `float` has no attribute `__getitem__`. + fac1[:, None, None] * skews + + fac2[:, None, None] * skews_square + + torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None] + ) + + return R, rot_angles, skews, skews_square + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +class CameraPoseDeltaCollection(torch.nn.Module): + def __init__(self, number_of_points=10) -> None: + super().__init__() + zero_rotation = torch.ones([1, 3]) * 1e-3 + zero_translation = torch.zeros([1, 3, 1]) + 1e-4 + for n in range(number_of_points): + self.register_parameter( + f"delta_rotation_{n}", nn.Parameter(zero_rotation)) + self.register_parameter( + f"delta_translation_{n}", nn.Parameter(zero_translation) + ) + self.register_buffer("zero_rotation", torch.eye(3)[None, ...]) + self.register_buffer("zero_translation", torch.zeros([1, 3, 1])) + self.traced_so3_exp_map = None + self.number_of_points = number_of_points + + def get_rotation_and_translation_params(self): + rotation_params = [] + translation_params = [] + for n in range(self.number_of_points): + rotation_params.append(getattr(self, f"delta_rotation_{n}")) + translation_params.append(getattr(self, f"delta_translation_{n}")) + return rotation_params, translation_params + + def set_rotation_and_translation(self, index, rotaion_so3, translation): + delta_rotation = getattr(self, f"delta_rotation_{index}") + delta_translation = getattr(self, f"delta_translation_{index}") + delta_rotation.data = rotaion_so3.detach().clone() + delta_translation.data = translation.detach().clone() + + def set_first_frame_pose(self, R, t): + self.zero_rotation.data = R.detach().clone().reshape([1, 3, 3]) + self.zero_translation.data = t.detach().clone().reshape([1, 3, 1]) + + def get_raw_value(self, index): + so3 = getattr(self, f"delta_rotation_{index}") + translation = getattr(self, f"delta_translation_{index}") + return so3, translation + + def forward(self, list_of_index): + se_3 = [] + t_out = [] + for idx in list_of_index: + delta_rotation, delta_translation = self.get_raw_value(idx) + se_3.append(delta_rotation) + t_out.append(delta_translation) + se_3 = torch.cat(se_3, dim=0) + t_out = torch.cat(t_out, dim=0) + if self.traced_so3_exp_map is None: + self.traced_so3_exp_map = torch.jit.trace( + _so3_exp_map, (se_3,)) + R_out = _so3_exp_map(se_3)[0] + return R_out, t_out + + def forward_index(self, index): + # if index == 0: + # return self.zero_rotation, self.zero_translation + # else: + delta_rotation, delta_translation = self.get_raw_value(index) + if self.traced_so3_exp_map is None: + self.traced_so3_exp_map = torch.jit.trace( + _so3_exp_map, (delta_rotation,)) + R = _so3_exp_map(delta_rotation)[0] + return R, delta_translation + + +class DepthScaleShiftCollection(torch.nn.Module): + def __init__(self, n_points=10, use_inverse=False, grid_size=1): + super().__init__() + self.grid_size = grid_size + for n in range(n_points): + self.register_parameter( + f"shift_{n}", nn.Parameter(torch.FloatTensor([0.0])) + ) + self.register_parameter( + f"scale_{n}", nn.Parameter( + torch.ones([1, 1, grid_size, grid_size])) + ) + + self.use_inverse = use_inverse + self.output_shape = None + + def set_outputshape(self, output_shape): + self.output_shape = output_shape + + def forward(self, index): + shift = getattr(self, f"shift_{index}") + scale = getattr(self, f"scale_{index}") + if self.use_inverse: + scale = torch.exp(scale) # 1 / (scale ** 4) + if self.grid_size != 1: + scale = F.interpolate(scale, self.output_shape, + mode='bilinear', align_corners=True) + return scale, shift + + def set_scale(self, index, scale): + scale_param = getattr(self, f"scale_{index}") + if self.use_inverse: + scale = math.log(scale) # (1 / scale) ** 0.25 + scale_param.data.fill_(scale) + + def get_scale_data(self, index): + scale = getattr(self, f"scale_{index}").data + if self.use_inverse: + scale = torch.exp(scale) # 1 / (scale ** 4) + if self.grid_size != 1: + scale = F.interpolate(scale, self.output_shape, + mode='bilinear', align_corners=True) + return scale + + +def check_R_shape(R): + r0, r1, r2 = R.shape + assert r1 == 3 and r2 == 3 + + +def check_t_shape(t): + t0, t1, t2 = t.shape + assert t1 == 3 and t2 == 1 + + +class DepthBasedWarping(nn.Module): + # tested + def __init__(self) -> None: + super().__init__() + + def generate_grid(self, H, W, device): + yy, xx = torch.meshgrid( + torch.arange(H, device=device, dtype=torch.float32), + torch.arange(W, device=device, dtype=torch.float32), + ) + self.coord = torch.ones( + [1, 3, H, W], device=device, dtype=torch.float32) + self.coord[0, 0, ...] = xx + self.coord[0, 1, ...] = yy + self.coord = self.coord.reshape([1, 3, H * W]) + self.jitted_warp_by_disp = None + + def reproject_depth(self, src_R, src_t, src_disp, tgt_R, tgt_t, tgt_disp, K_src, K_inv_src, K_trg, K_inv_trg, eps=1e-6, check_shape=False): + if check_shape: + check_R_shape(src_R) + check_R_shape(tgt_R) + check_t_shape(src_t) + check_t_shape(tgt_t) + check_t_shape(src_disp) + check_t_shape(tgt_disp) + device = src_disp.device + B, _, H, W = src_disp.shape + if not hasattr(self, "coord"): + self.generate_grid(src_disp.shape[2], src_disp.shape[3], device) + else: + if self.coord.shape[-1] != H * W: + del self.coord + self.generate_grid(H, W, device) + return reproject_depth(src_R, src_t, src_disp, tgt_R, tgt_t, tgt_disp, K_src, K_inv_src, K_trg, K_inv_trg, self.coord, eps=eps) + + def unproject_depth(self, disp, R, t, K_inv, eps=1e-6, check_shape=False): + if check_shape: + check_R_shape(R) + check_R_shape(t) + + _, _, H, W = disp.shape + B = R.shape[0] + device = disp.device + if not hasattr(self, "coord"): + self.generate_grid(H, W, device=device) + else: + if self.coord.shape[-1] != H * W: + del self.coord + self.generate_grid(H, W, device=device) + # if self.jitted_warp_by_disp is None: + # self.jitted_warp_by_disp = torch.jit.trace( + # warp_by_disp, (src_R.detach(), src_t.detach(), tgt_R.detach(), tgt_t.detach(), K, src_disp.detach(), self.coord, inv_K)) + return unproject_depth(1 / (disp + eps), K_inv, R, t, self.coord) + + def forward( + self, + src_R, + src_t, + tgt_R, + tgt_t, + src_disp, + K, + inv_K, + eps=1e-6, + use_depth=False, + check_shape=False, + debug_mode=False, + ): + """warp the current depth frame and generate flow field. + + Args: + src_R (FloatTensor): 1x3x3 + src_t (FloatTensor): 1x3x1 + tgt_R (FloatTensor): Nx3x3 + tgt_t (FloatTensor): Nx3x1 + src_disp (FloatTensor): Nx1XHxW + src_K (FloatTensor): 1x3x3 + """ + if check_shape: + check_R_shape(src_R) + check_R_shape(tgt_R) + check_t_shape(src_t) + check_t_shape(tgt_t) + + _, _, H, W = src_disp.shape + B = tgt_R.shape[0] + device = src_disp.device + if not hasattr(self, "coord"): + self.generate_grid(H, W, device=device) + else: + if self.coord.shape[-1] != H * W: + del self.coord + self.generate_grid(H, W, device=device) + # if self.jitted_warp_by_disp is None: + # self.jitted_warp_by_disp = torch.jit.trace( + # warp_by_disp, (src_R.detach(), src_t.detach(), tgt_R.detach(), tgt_t.detach(), K, src_disp.detach(), self.coord, inv_K)) + + return warp_by_disp(src_R, src_t, tgt_R, tgt_t, K, src_disp, self.coord, inv_K, debug_mode, use_depth) + + +class DepthToXYZ(nn.Module): + # tested + def __init__(self) -> None: + super().__init__() + + def generate_grid(self, H, W, device): + yy, xx = torch.meshgrid( + torch.arange(H, device=device, dtype=torch.float32), + torch.arange(W, device=device, dtype=torch.float32), + ) + self.coord = torch.ones( + [1, 3, H, W], device=device, dtype=torch.float32) + self.coord[0, 0, ...] = xx + self.coord[0, 1, ...] = yy + self.coord = self.coord.reshape([1, 3, H * W]) + + def forward(self, disp, K_inv, R, t, eps=1e-6, check_shape=False): + """warp the current depth frame and generate flow field. + + Args: + src_R (FloatTensor): 1x3x3 + src_t (FloatTensor): 1x3x1 + tgt_R (FloatTensor): Nx3x3 + tgt_t (FloatTensor): Nx3x1 + src_disp (FloatTensor): Nx1XHxW + src_K (FloatTensor): 1x3x3 + """ + if check_shape: + check_R_shape(R) + check_R_shape(t) + + _, _, H, W = disp.shape + B = R.shape[0] + device = disp.device + if not hasattr(self, "coord"): + self.generate_grid(H, W, device=device) + else: + if self.coord.shape[-1] != H * W: + del self.coord + self.generate_grid(H, W, device=device) + # if self.jitted_warp_by_disp is None: + # self.jitted_warp_by_disp = torch.jit.trace( + # warp_by_disp, (src_R.detach(), src_t.detach(), tgt_R.detach(), tgt_t.detach(), K, src_disp.detach(), self.coord, inv_K)) + + return unproject_depth(1 / (disp + eps), K_inv, R, t, self.coord) + +class OccMask(torch.nn.Module): + def __init__(self, th=3): + super(OccMask, self).__init__() + self.th = th + self.base_coord = None + + def init_grid(self, shape, device): + H, W = shape + hh, ww = torch.meshgrid(torch.arange( + H).float(), torch.arange(W).float()) + coord = torch.zeros([1, H, W, 2]) + coord[0, ..., 0] = ww + coord[0, ..., 1] = hh + self.base_coord = coord.to(device) + self.W = W + self.H = H + + @torch.no_grad() + def get_oob_mask(self, base_coord, flow_1_2): + target_range = base_coord + flow_1_2.permute([0, 2, 3, 1]) + oob_mask = (target_range[..., 0] < 0) | (target_range[..., 0] > self.W-1) | ( + target_range[..., 1] < 0) | (target_range[..., 1] > self.H-1) + return ~oob_mask[:, None, ...] + + @torch.no_grad() + def get_flow_inconsistency_tensor(self, base_coord, flow_1_2, flow_2_1): + B, C, H, W = flow_1_2.shape + sample_grids = base_coord + flow_1_2.permute([0, 2, 3, 1]) + sample_grids[..., 0] /= (W - 1) / 2 + sample_grids[..., 1] /= (H - 1) / 2 + sample_grids -= 1 + sampled_flow = F.grid_sample( + flow_2_1, sample_grids, align_corners=True) + return torch.abs((sampled_flow+flow_1_2).sum(1, keepdim=True)) + + def forward(self, flow_1_2, flow_2_1): + B, _, H, W = flow_1_2.shape + if self.base_coord is None: + self.init_grid([H, W], device=flow_1_2.device) + base_coord = self.base_coord.expand([B, -1, -1, -1]) + oob_mask = self.get_oob_mask(base_coord, flow_1_2) + flow_inconsistency_tensor = self.get_flow_inconsistency_tensor( + base_coord, flow_1_2, flow_2_1) + valid_flow_mask = flow_inconsistency_tensor < self.th + return valid_flow_mask*oob_mask \ No newline at end of file diff --git a/monst3r/dust3r/utils/image.py b/monst3r/dust3r/utils/image.py new file mode 100644 index 0000000..52d3312 --- /dev/null +++ b/monst3r/dust3r/utils/image.py @@ -0,0 +1,322 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions about images (loading/converting...) +# -------------------------------------------------------- +import os +import torch +import numpy as np +import PIL.Image +from PIL.ImageOps import exif_transpose +import torchvision.transforms as tvf +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa +import glob +import imageio +import matplotlib.pyplot as plt + +try: + from pillow_heif import register_heif_opener # noqa + register_heif_opener() + heif_support_enabled = True +except ImportError: + heif_support_enabled = False + +ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) +ToTensor = tvf.ToTensor() +TAG_FLOAT = 202021.25 + +def depth_read(filename): + """ Read depth data from file, return as numpy array. """ + f = open(filename,'rb') + check = np.fromfile(f,dtype=np.float32,count=1)[0] + assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) + width = np.fromfile(f,dtype=np.int32,count=1)[0] + height = np.fromfile(f,dtype=np.int32,count=1)[0] + size = width*height + assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) + depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width)) + return depth + +def cam_read(filename): + """ Read camera data, return (M,N) tuple. + + M is the intrinsic matrix, N is the extrinsic matrix, so that + + x = M*N*X, + with x being a point in homogeneous image pixel coordinates, X being a + point in homogeneous world coordinates. + """ + f = open(filename,'rb') + check = np.fromfile(f,dtype=np.float32,count=1)[0] + assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) + M = np.fromfile(f,dtype='float64',count=9).reshape((3,3)) + N = np.fromfile(f,dtype='float64',count=12).reshape((3,4)) + return M,N + +def flow_read(filename): + """ Read optical flow from file, return (U,V) tuple. + + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + f = open(filename,'rb') + check = np.fromfile(f,dtype=np.float32,count=1)[0] + assert check == TAG_FLOAT, ' flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) + width = np.fromfile(f,dtype=np.int32,count=1)[0] + height = np.fromfile(f,dtype=np.int32,count=1)[0] + size = width*height + assert width > 0 and height > 0 and size > 1 and size < 100000000, ' flow_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) + tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2)) + u = tmp[:,np.arange(width)*2] + v = tmp[:,np.arange(width)*2 + 1] + return u,v + +def img_to_arr( img ): + if isinstance(img, str): + img = imread_cv2(img) + return img + +def imread_cv2(path, options=cv2.IMREAD_COLOR): + """ Open an image or a depthmap with opencv-python. + """ + if path.endswith(('.exr', 'EXR')): + options = cv2.IMREAD_ANYDEPTH + img = cv2.imread(path, options) + if img is None: + raise IOError(f'Could not load image={path} with {options=}') + if img.ndim == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def rgb(ftensor, true_shape=None): + if isinstance(ftensor, list): + return [rgb(x, true_shape=true_shape) for x in ftensor] + if isinstance(ftensor, torch.Tensor): + ftensor = ftensor.detach().cpu().numpy() # H,W,3 + if ftensor.ndim == 3 and ftensor.shape[0] == 3: + ftensor = ftensor.transpose(1, 2, 0) + elif ftensor.ndim == 4 and ftensor.shape[1] == 3: + ftensor = ftensor.transpose(0, 2, 3, 1) + if true_shape is not None: + H, W = true_shape + ftensor = ftensor[:H, :W] + if ftensor.dtype == np.uint8: + img = np.float32(ftensor) / 255 + else: + img = (ftensor * 0.5) + 0.5 + return img.clip(min=0, max=1) + + +def _resize_pil_image(img, long_edge_size, nearest=False): + S = max(img.size) + if S > long_edge_size: + interp = PIL.Image.LANCZOS if not nearest else PIL.Image.NEAREST + elif S <= long_edge_size: + interp = PIL.Image.BICUBIC + new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size) + return img.resize(new_size, interp) + + +def crop_img(img, size, square_ok=False, nearest=False, crop=True): + W1, H1 = img.size + if size == 224: + # resize short side to 224 (then crop) + img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)), nearest=nearest) + else: + # resize long side to 512 + img = _resize_pil_image(img, size, nearest=nearest) + W, H = img.size + cx, cy = W//2, H//2 + if size == 224: + half = min(cx, cy) + img = img.crop((cx-half, cy-half, cx+half, cy+half)) + else: + halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8 + if not (square_ok) and W == H: + halfh = 3*halfw/4 + if crop: + img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh)) + else: # resize + img = img.resize((2*halfw, 2*halfh), PIL.Image.LANCZOS) + return img + +def load_images(folder_or_list, size, square_ok=False, verbose=True, dynamic_mask_root=None, crop=True, fps=0, num_frames=110): + """Open and convert all images or videos in a list or folder to proper input format for DUSt3R.""" + if isinstance(folder_or_list, str): + if verbose: + print(f'>> Loading images from {folder_or_list}') + # if folder_or_list is a folder, load all images in the folder + if os.path.isdir(folder_or_list): + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + else: # the folder_content will be the folder_or_list itself + root, folder_content = '', [folder_or_list] + + elif isinstance(folder_or_list, list): + if verbose: + print(f'>> Loading a list of {len(folder_or_list)} items') + root, folder_content = '', folder_or_list + + else: + raise ValueError(f'Bad input {folder_or_list=} ({type(folder_or_list)})') + + supported_images_extensions = ['.jpg', '.jpeg', '.png'] + supported_video_extensions = ['.mp4', '.avi', '.mov'] + if heif_support_enabled: + supported_images_extensions += ['.heic', '.heif'] + supported_images_extensions = tuple(supported_images_extensions) + supported_video_extensions = tuple(supported_video_extensions) + + imgs = [] + # Sort items by their names + folder_content = sorted(folder_content, key=lambda x: x.split('/')[-1]) + for path in folder_content: + full_path = os.path.join(root, path) + if path.lower().endswith(supported_images_extensions): + # Process image files + img = exif_transpose(PIL.Image.open(full_path)).convert('RGB') + W1, H1 = img.size + img = crop_img(img, size, square_ok=square_ok, crop=crop) + W2, H2 = img.size + + if verbose: + print(f' - Adding {path} with resolution {W1}x{H1} --> {W2}x{H2}') + + single_dict = dict( + img=ImgNorm(img)[None], + true_shape=np.int32([img.size[::-1]]), + idx=len(imgs), + instance=full_path, + mask=~(ToTensor(img)[None].sum(1) <= 0.01) + ) + + if dynamic_mask_root is not None: + dynamic_mask_path = os.path.join(dynamic_mask_root, os.path.basename(path)) + else: # Sintel dataset handling + dynamic_mask_path = full_path.replace('final', 'dynamic_label_perfect').replace('clean', 'dynamic_label_perfect') + + if os.path.exists(dynamic_mask_path): + dynamic_mask = PIL.Image.open(dynamic_mask_path).convert('L') + dynamic_mask = crop_img(dynamic_mask, size, square_ok=square_ok) + dynamic_mask = ToTensor(dynamic_mask)[None].sum(1) > 0.99 # "1" means dynamic + if dynamic_mask.sum() < 0.8 * dynamic_mask.numel(): # Consider static if over 80% is dynamic + single_dict['dynamic_mask'] = dynamic_mask + else: + single_dict['dynamic_mask'] = torch.zeros_like(single_dict['mask']) + else: + single_dict['dynamic_mask'] = torch.zeros_like(single_dict['mask']) + + imgs.append(single_dict) + + elif path.lower().endswith(supported_video_extensions): + # Process video files + if verbose: + print(f'>> Loading video from {full_path}') + cap = cv2.VideoCapture(full_path) + if not cap.isOpened(): + print(f'Error opening video file {full_path}') + continue + + video_fps = cap.get(cv2.CAP_PROP_FPS) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + if video_fps == 0: + print(f'Error: Video FPS is 0 for {full_path}') + cap.release() + continue + if fps > 0: + frame_interval = max(1, int(round(video_fps / fps))) + else: + frame_interval = 1 + frame_indices = list(range(0, total_frames, frame_interval)) + if num_frames is not None: + frame_indices = frame_indices[:num_frames] + + if verbose: + print(f' - Video FPS: {video_fps}, Frame Interval: {frame_interval}, Total Frames to Read: {len(frame_indices)}') + + for frame_idx in frame_indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) + ret, frame = cap.read() + if not ret: + break # End of video + + img = PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + W1, H1 = img.size + img = crop_img(img, size, square_ok=square_ok, crop=crop) + W2, H2 = img.size + + if verbose: + print(f' - Adding frame {frame_idx} from {path} with resolution {W1}x{H1} --> {W2}x{H2}') + + single_dict = dict( + img=ImgNorm(img)[None], + true_shape=np.int32([img.size[::-1]]), + idx=len(imgs), + instance=f'{full_path}_frame_{frame_idx}', + mask=~(ToTensor(img)[None].sum(1) <= 0.01) + ) + + # Dynamic masks for video frames are set to zeros by default + single_dict['dynamic_mask'] = torch.zeros_like(single_dict['mask']) + + imgs.append(single_dict) + + cap.release() + + else: + continue # Skip unsupported file types + + assert imgs, 'No images found at ' + root + if verbose: + print(f' (Found {len(imgs)} images)') + return imgs + +def enlarge_seg_masks(folder, kernel_size=5, prefix="dynamic_mask"): + mask_pathes = glob.glob(f'{folder}/{prefix}_*.png') + for mask_path in mask_pathes: + mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) + kernel = np.ones((kernel_size, kernel_size),np.uint8) + enlarged_mask = cv2.dilate(mask, kernel, iterations=1) + cv2.imwrite(mask_path.replace(prefix, 'enlarged_dynamic_mask'), enlarged_mask) + +def show_mask(mask, ax, obj_id=None, random_color=False): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + cmap = plt.get_cmap("tab10") + cmap_idx = 1 if obj_id is None else obj_id + color = np.array([*cmap(cmap_idx)[:3], 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + +def get_overlaied_gif(folder, img_format="frame_*.png", mask_format="dynamic_mask_*.png", output_path="_overlaied.gif"): + img_paths = glob.glob(f'{folder}/{img_format}') + mask_paths = glob.glob(f'{folder}/{mask_format}') + assert len(img_paths) == len(mask_paths), f"Number of images and masks should be the same, got {len(img_paths)} images and {len(mask_paths)} masks" + img_paths = sorted(img_paths) + mask_paths = sorted(mask_paths, key=lambda x: int(x.split('_')[-1].split('.')[0])) + frames = [] + for img_path, mask_path in zip(img_paths, mask_paths): + # Read image and convert to RGB for Matplotlib + img = cv2.imread(img_path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + # Read mask and normalize + mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) + mask = mask.astype(np.float32) / 255.0 + # Create figure and axis + fig, ax = plt.subplots(figsize=(img.shape[1]/100, img.shape[0]/100), dpi=100) + ax.imshow(img) + # Overlay mask using show_mask + show_mask(mask, ax) + ax.axis('off') + # Render the figure to a numpy array + fig.canvas.draw() + img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + frames.append(img_array) + plt.close(fig) # Close the figure to free memory + # Save frames as a GIF using imageio + imageio.mimsave(os.path.join(folder,output_path), frames, fps=10) diff --git a/monst3r/dust3r/utils/misc.py b/monst3r/dust3r/utils/misc.py new file mode 100644 index 0000000..be0051e --- /dev/null +++ b/monst3r/dust3r/utils/misc.py @@ -0,0 +1,200 @@ +# -------------------------------------------------------- +# utilitary functions for DUSt3R +# -------------------------------------------------------- +import torch +import cv2 +import numpy as np +from dust3r.utils.vo_eval import save_trajectory_tum_format +from PIL import Image + +def get_stride_distribution(strides, dist_type='uniform'): + + # input strides sorted by descreasing order by default + + if dist_type == 'uniform': + dist = np.ones(len(strides)) / len(strides) + elif dist_type == 'exponential': + lambda_param = 1.0 + dist = np.exp(-lambda_param * np.arange(len(strides))) + elif dist_type.startswith('linear'): # e.g., linear_1_2 + try: + start, end = map(float, dist_type.split('_')[1:]) + dist = np.linspace(start, end, len(strides)) + except ValueError: + raise ValueError(f'Invalid linear distribution format: {dist_type}') + else: + raise ValueError('Unknown distribution type %s' % dist_type) + + # normalize to sum to 1 + return dist / np.sum(dist) + + +def fill_default_args(kwargs, func): + import inspect # a bit hacky but it works reliably + signature = inspect.signature(func) + + for k, v in signature.parameters.items(): + if v.default is inspect.Parameter.empty: + continue + kwargs.setdefault(k, v.default) + + return kwargs + + +def freeze_all_params(modules): + for module in modules: + try: + for n, param in module.named_parameters(): + param.requires_grad = False + except AttributeError: + # module is directly a parameter + module.requires_grad = False + + +def is_symmetrized(gt1, gt2): + x = gt1['instance'] + y = gt2['instance'] + if len(x) == len(y) and len(x) == 1: + return False # special case of batchsize 1 + ok = True + for i in range(0, len(x), 2): + ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i]) + return ok + + +def flip(tensor): + """ flip so that tensor[0::2] <=> tensor[1::2] """ + return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1) + + +def interleave(tensor1, tensor2): + res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) + res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) + return res1, res2 + + +def transpose_to_landscape(head, activate=True): + """ Predict in the correct aspect-ratio, + then transpose the result in landscape + and stack everything back together. + """ + def wrapper_no(decout, true_shape): + B = len(true_shape) + assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical' + H, W = true_shape[0].cpu().tolist() + res = head(decout, (H, W)) + return res + + def wrapper_yes(decout, true_shape): + B = len(true_shape) + # by definition, the batch is in landscape mode so W >= H + H, W = int(true_shape.min()), int(true_shape.max()) + + height, width = true_shape.T + is_landscape = (width >= height) + is_portrait = ~is_landscape + + # true_shape = true_shape.cpu() + if is_landscape.all(): + return head(decout, (H, W)) + if is_portrait.all(): + return transposed(head(decout, (W, H))) + + # batch is a mix of both portraint & landscape + def selout(ar): return [d[ar] for d in decout] + l_result = head(selout(is_landscape), (H, W)) + p_result = transposed(head(selout(is_portrait), (W, H))) + + # allocate full result + result = {} + for k in l_result | p_result: + x = l_result[k].new(B, *l_result[k].shape[1:]) + x[is_landscape] = l_result[k] + x[is_portrait] = p_result[k] + result[k] = x + + return result + + return wrapper_yes if activate else wrapper_no + + +def transposed(dic): + return {k: v.swapaxes(1, 2) for k, v in dic.items()} + + +def invalid_to_nans(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = float('nan') + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr + + +def invalid_to_zeros(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = 0 + nnz = valid_mask.view(len(valid_mask), -1).sum(1) + else: + nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr, nnz + +def save_tum_poses(traj, path): + # traj = self.get_tum_poses() + save_trajectory_tum_format(traj, path) + return traj[0] # return the poses + +def save_focals(focals, path): + # convert focal to txt + # focals = self.get_focals() + np.savetxt(path, focals.detach().cpu().numpy(), fmt='%.6f') + return focals + +def save_intrinsics(K_raw, path): + # K_raw = self.get_intrinsics() + K = K_raw.reshape(-1, 9) + np.savetxt(path, K.detach().cpu().numpy(), fmt='%.6f') + return K_raw + +def save_conf_maps(conf, path): + # conf = self.get_conf() + for i, c in enumerate(conf): + np.save(f'{path}/conf_{i}.npy', c.detach().cpu().numpy()) + return conf + +def save_rgb_imgs(imgs, path): + # imgs = self.imgs + for i, img in enumerate(imgs): + # convert from rgb to bgr + img = img[..., ::-1] + cv2.imwrite(f'{path}/frame_{i:04d}.png', img*255) + return imgs + +def save_dynamic_masks(dynamic_masks, path): + # dynamic_masks = self.dynamic_masks + for i, dynamic_mask in enumerate(dynamic_masks): + cv2.imwrite(f'{path}/dynamic_mask_{i}.png', (dynamic_mask * 255).detach().cpu().numpy().astype(np.uint8)) + return dynamic_masks + +def save_depth_maps(depth_maps, path): + images = [] + for i, depth_map in enumerate(depth_maps): + depth_map_colored = cv2.applyColorMap((depth_map * 255).detach().cpu().numpy().astype(np.uint8), cv2.COLORMAP_JET) + img_path = f'{path}/frame_{(i):04d}.png' + cv2.imwrite(img_path, depth_map_colored) + images.append(Image.open(img_path)) + # Save npy file + np.save(f'{path}/frame_{(i):04d}.npy', depth_map.detach().cpu().numpy()) + + # Save gif using Pillow + images[0].save(f'{path}/_depth_maps.gif', save_all=True, append_images=images[1:], duration=100, loop=0) + return depth_maps + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.detach().cpu() + if isinstance(x, list): + return [to_cpu(xx) for xx in x] \ No newline at end of file diff --git a/monst3r/dust3r/utils/parallel.py b/monst3r/dust3r/utils/parallel.py new file mode 100644 index 0000000..06ae7fe --- /dev/null +++ b/monst3r/dust3r/utils/parallel.py @@ -0,0 +1,79 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for multiprocessing +# -------------------------------------------------------- +from tqdm import tqdm +from multiprocessing.dummy import Pool as ThreadPool +from multiprocessing import cpu_count + + +def parallel_threads(function, args, workers=0, star_args=False, kw_args=False, front_num=1, Pool=ThreadPool, **tqdm_kw): + """ tqdm but with parallel execution. + + Will essentially return + res = [ function(arg) # default + function(*arg) # if star_args is True + function(**arg) # if kw_args is True + for arg in args] + + Note: + the first elements of args will not be parallelized. + This can be useful for debugging. + """ + while workers <= 0: + workers += cpu_count() + if workers == 1: + front_num = float('inf') + + # convert into an iterable + try: + n_args_parallel = len(args) - front_num + except TypeError: + n_args_parallel = None + args = iter(args) + + # sequential execution first + front = [] + while len(front) < front_num: + try: + a = next(args) + except StopIteration: + return front # end of the iterable + front.append(function(*a) if star_args else function(**a) if kw_args else function(a)) + + # then parallel execution + out = [] + with Pool(workers) as pool: + # Pass the elements of args into function + if star_args: + futures = pool.imap(starcall, [(function, a) for a in args]) + elif kw_args: + futures = pool.imap(starstarcall, [(function, a) for a in args]) + else: + futures = pool.imap(function, args) + # Print out the progress as tasks complete + for f in tqdm(futures, total=n_args_parallel, **tqdm_kw): + out.append(f) + return front + out + + +def parallel_processes(*args, **kwargs): + """ Same as parallel_threads, with processes + """ + import multiprocessing as mp + kwargs['Pool'] = mp.Pool + return parallel_threads(*args, **kwargs) + + +def starcall(args): + """ convenient wrapper for Process.Pool """ + function, args = args + return function(*args) + + +def starstarcall(args): + """ convenient wrapper for Process.Pool """ + function, args = args + return function(**args) diff --git a/monst3r/dust3r/utils/path_to_croco.py b/monst3r/dust3r/utils/path_to_croco.py new file mode 100644 index 0000000..39226ce --- /dev/null +++ b/monst3r/dust3r/utils/path_to_croco.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# CroCo submodule import +# -------------------------------------------------------- + +import sys +import os.path as path +HERE_PATH = path.normpath(path.dirname(__file__)) +CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco')) +CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models') +# check the presence of models directory in repo to be sure its cloned +if path.isdir(CROCO_MODELS_PATH): + # workaround for sibling import + sys.path.insert(0, CROCO_REPO_PATH) +else: + raise ImportError(f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n " + "Did you forget to run 'git submodule update --init --recursive' ?") diff --git a/monst3r/dust3r/utils/po_utils/__init__.py b/monst3r/dust3r/utils/po_utils/__init__.py new file mode 100644 index 0000000..9f72a50 --- /dev/null +++ b/monst3r/dust3r/utils/po_utils/__init__.py @@ -0,0 +1 @@ +# junyi \ No newline at end of file diff --git a/monst3r/dust3r/utils/po_utils/basic.py b/monst3r/dust3r/utils/po_utils/basic.py new file mode 100644 index 0000000..fbfa069 --- /dev/null +++ b/monst3r/dust3r/utils/po_utils/basic.py @@ -0,0 +1,397 @@ +import os +import numpy as np +from os.path import isfile +import torch +import torch.nn.functional as F +EPS = 1e-6 +import copy + +def sub2ind(height, width, y, x): + return y*width + x + +def ind2sub(height, width, ind): + y = ind // width + x = ind % width + return y, x + +def get_lr_str(lr): + lrn = "%.1e" % lr # e.g., 5.0e-04 + lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4 + return lrn + +def strnum(x): + s = '%g' % x + if '.' in s: + if x < 1.0: + s = s[s.index('.'):] + s = s[:min(len(s),4)] + return s + +def assert_same_shape(t1, t2): + for (x, y) in zip(list(t1.shape), list(t2.shape)): + assert(x==y) + +def print_stats(name, tensor): + shape = tensor.shape + tensor = tensor.detach().cpu().numpy() + print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape) + +def print_stats_py(name, tensor): + shape = tensor.shape + print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape) + +def print_(name, tensor): + tensor = tensor.detach().cpu().numpy() + print(name, tensor, tensor.shape) + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + +def normalize_single(d): + # d is a whatever shape torch tensor + dmin = torch.min(d) + dmax = torch.max(d) + d = (d-dmin)/(EPS+(dmax-dmin)) + return d + +def normalize(d): + # d is B x whatever. normalize within each element of the batch + out = torch.zeros(d.size()) + if d.is_cuda: + out = out.cuda() + B = list(d.size())[0] + for b in list(range(B)): + out[b] = normalize_single(d[b]) + return out + +def hard_argmax2d(tensor): + B, C, Y, X = list(tensor.shape) + assert(C==1) + + # flatten the Tensor along the height and width axes + flat_tensor = tensor.reshape(B, -1) + # argmax of the flat tensor + argmax = torch.argmax(flat_tensor, dim=1) + + # convert the indices into 2d coordinates + argmax_y = torch.floor(argmax / X) # row + argmax_x = argmax % X # col + + argmax_y = argmax_y.reshape(B) + argmax_x = argmax_x.reshape(B) + return argmax_y, argmax_x + +def argmax2d(heat, hard=True): + B, C, Y, X = list(heat.shape) + assert(C==1) + + if hard: + # hard argmax + loc_y, loc_x = hard_argmax2d(heat) + loc_y = loc_y.float() + loc_x = loc_x.float() + else: + heat = heat.reshape(B, Y*X) + prob = torch.nn.functional.softmax(heat, dim=1) + + grid_y, grid_x = meshgrid2d(B, Y, X) + + grid_y = grid_y.reshape(B, -1) + grid_x = grid_x.reshape(B, -1) + + loc_y = torch.sum(grid_y*prob, dim=1) + loc_x = torch.sum(grid_x*prob, dim=1) + # these are B + + return loc_y, loc_x + +def reduce_masked_mean(x, mask, dim=None, keepdim=False): + # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting + # returns shape-1 + # axis can be a list of axes + for (a,b) in zip(x.size(), mask.size()): + # if not b==1: + assert(a==b) # some shape mismatch! + # assert(x.size() == mask.size()) + prod = x*mask + if dim is None: + numer = torch.sum(prod) + denom = EPS+torch.sum(mask) + else: + numer = torch.sum(prod, dim=dim, keepdim=keepdim) + denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim) + + mean = numer/denom + return mean + +def reduce_masked_median(x, mask, keep_batch=False): + # x and mask are the same shape + assert(x.size() == mask.size()) + device = x.device + + B = list(x.shape)[0] + x = x.detach().cpu().numpy() + mask = mask.detach().cpu().numpy() + + if keep_batch: + x = np.reshape(x, [B, -1]) + mask = np.reshape(mask, [B, -1]) + meds = np.zeros([B], np.float32) + for b in list(range(B)): + xb = x[b] + mb = mask[b] + if np.sum(mb) > 0: + xb = xb[mb > 0] + meds[b] = np.median(xb) + else: + meds[b] = np.nan + meds = torch.from_numpy(meds).to(device) + return meds.float() + else: + x = np.reshape(x, [-1]) + mask = np.reshape(mask, [-1]) + if np.sum(mask) > 0: + x = x[mask > 0] + med = np.median(x) + else: + med = np.nan + med = np.array([med], np.float32) + med = torch.from_numpy(med).to(device) + return med.float() + +def pack_seqdim(tensor, B): + shapelist = list(tensor.shape) + B_, S = shapelist[:2] + assert(B==B_) + otherdims = shapelist[2:] + tensor = torch.reshape(tensor, [B*S]+otherdims) + return tensor + +def unpack_seqdim(tensor, B): + shapelist = list(tensor.shape) + BS = shapelist[0] + assert(BS%B==0) + otherdims = shapelist[1:] + S = int(BS/B) + tensor = torch.reshape(tensor, [B,S]+otherdims) + return tensor + +def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False): + # returns a meshgrid sized B x Y x X + + grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device)) + grid_y = torch.reshape(grid_y, [1, Y, 1]) + grid_y = grid_y.repeat(B, 1, X) + + grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device)) + grid_x = torch.reshape(grid_x, [1, 1, X]) + grid_x = grid_x.repeat(B, Y, 1) + + if norm: + grid_y, grid_x = normalize_grid2d( + grid_y, grid_x, Y, X) + + if stack: + # note we stack in xy order + # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) + if on_chans: + grid = torch.stack([grid_x, grid_y], dim=1) + else: + grid = torch.stack([grid_x, grid_y], dim=-1) + return grid + else: + return grid_y, grid_x + +def meshgrid3d(B, Z, Y, X, stack=False, norm=False, device='cuda'): + # returns a meshgrid sized B x Z x Y x X + + grid_z = torch.linspace(0.0, Z-1, Z, device=device) + grid_z = torch.reshape(grid_z, [1, Z, 1, 1]) + grid_z = grid_z.repeat(B, 1, Y, X) + + grid_y = torch.linspace(0.0, Y-1, Y, device=device) + grid_y = torch.reshape(grid_y, [1, 1, Y, 1]) + grid_y = grid_y.repeat(B, Z, 1, X) + + grid_x = torch.linspace(0.0, X-1, X, device=device) + grid_x = torch.reshape(grid_x, [1, 1, 1, X]) + grid_x = grid_x.repeat(B, Z, Y, 1) + + # if cuda: + # grid_z = grid_z.cuda() + # grid_y = grid_y.cuda() + # grid_x = grid_x.cuda() + + if norm: + grid_z, grid_y, grid_x = normalize_grid3d( + grid_z, grid_y, grid_x, Z, Y, X) + + if stack: + # note we stack in xyz order + # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) + grid = torch.stack([grid_x, grid_y, grid_z], dim=-1) + return grid + else: + return grid_z, grid_y, grid_x + +def normalize_grid2d(grid_y, grid_x, Y, X, clamp_extreme=True): + # make things in [-1,1] + grid_y = 2.0*(grid_y / float(Y-1)) - 1.0 + grid_x = 2.0*(grid_x / float(X-1)) - 1.0 + + if clamp_extreme: + grid_y = torch.clamp(grid_y, min=-2.0, max=2.0) + grid_x = torch.clamp(grid_x, min=-2.0, max=2.0) + + return grid_y, grid_x + +def normalize_grid3d(grid_z, grid_y, grid_x, Z, Y, X, clamp_extreme=True): + # make things in [-1,1] + grid_z = 2.0*(grid_z / float(Z-1)) - 1.0 + grid_y = 2.0*(grid_y / float(Y-1)) - 1.0 + grid_x = 2.0*(grid_x / float(X-1)) - 1.0 + + if clamp_extreme: + grid_z = torch.clamp(grid_z, min=-2.0, max=2.0) + grid_y = torch.clamp(grid_y, min=-2.0, max=2.0) + grid_x = torch.clamp(grid_x, min=-2.0, max=2.0) + + return grid_z, grid_y, grid_x + +def gridcloud2d(B, Y, X, norm=False, device='cuda'): + # we want to sample for each location in the grid + grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device) + x = torch.reshape(grid_x, [B, -1]) + y = torch.reshape(grid_y, [B, -1]) + # these are B x N + xy = torch.stack([x, y], dim=2) + # this is B x N x 2 + return xy + +def gridcloud3d(B, Z, Y, X, norm=False, device='cuda'): + # we want to sample for each location in the grid + grid_z, grid_y, grid_x = meshgrid3d(B, Z, Y, X, norm=norm, device=device) + x = torch.reshape(grid_x, [B, -1]) + y = torch.reshape(grid_y, [B, -1]) + z = torch.reshape(grid_z, [B, -1]) + # these are B x N + xyz = torch.stack([x, y, z], dim=2) + # this is B x N x 3 + return xyz + +import re +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def normalize_boxlist2d(boxlist2d, H, W): + boxlist2d = boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin / float(H) + ymax = ymax / float(H) + xmin = xmin / float(W) + xmax = xmax / float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_boxlist2d(boxlist2d, H, W): + boxlist2d = boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin * float(H) + ymax = ymax * float(H) + xmin = xmin * float(W) + xmax = xmax * float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_box2d(box2d, H, W): + return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def normalize_box2d(box2d, H, W): + return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def get_gaussian_kernel_2d(channels, kernel_size=3, sigma=2.0, mid_one=False): + C = channels + xy_grid = gridcloud2d(C, kernel_size, kernel_size) # C x N x 2 + + mean = (kernel_size - 1)/2.0 + variance = sigma**2.0 + + gaussian_kernel = (1.0/(2.0*np.pi*variance)**1.5) * torch.exp(-torch.sum((xy_grid - mean)**2.0, dim=-1) / (2.0*variance)) # C X N + gaussian_kernel = gaussian_kernel.view(C, 1, kernel_size, kernel_size) # C x 1 x 3 x 3 + kernel_sum = torch.sum(gaussian_kernel, dim=(2,3), keepdim=True) + + gaussian_kernel = gaussian_kernel / kernel_sum # normalize + + if mid_one: + # normalize so that the middle element is 1 + maxval = gaussian_kernel[:,:,(kernel_size//2),(kernel_size//2)].reshape(C, 1, 1, 1) + gaussian_kernel = gaussian_kernel / maxval + + return gaussian_kernel + +def gaussian_blur_2d(input, kernel_size=3, sigma=2.0, reflect_pad=False, mid_one=False): + B, C, Z, X = input.shape + kernel = get_gaussian_kernel_2d(C, kernel_size, sigma, mid_one=mid_one) + if reflect_pad: + pad = (kernel_size - 1)//2 + out = F.pad(input, (pad, pad, pad, pad), mode='reflect') + out = F.conv2d(out, kernel, padding=0, groups=C) + else: + out = F.conv2d(input, kernel, padding=(kernel_size - 1)//2, groups=C) + return out + +def gradient2d(x, absolute=False, square=False, return_sum=False): + # x should be B x C x H x W + dh = x[:, :, 1:, :] - x[:, :, :-1, :] + dw = x[:, :, :, 1:] - x[:, :, :, :-1] + + zeros = torch.zeros_like(x) + zero_h = zeros[:, :, 0:1, :] + zero_w = zeros[:, :, :, 0:1] + dh = torch.cat([dh, zero_h], axis=2) + dw = torch.cat([dw, zero_w], axis=3) + if absolute: + dh = torch.abs(dh) + dw = torch.abs(dw) + if square: + dh = dh ** 2 + dw = dw ** 2 + if return_sum: + return dh+dw + else: + return dh, dw \ No newline at end of file diff --git a/monst3r/dust3r/utils/po_utils/geom.py b/monst3r/dust3r/utils/po_utils/geom.py new file mode 100644 index 0000000..76345e7 --- /dev/null +++ b/monst3r/dust3r/utils/po_utils/geom.py @@ -0,0 +1,547 @@ +import torch +import dust3r.utils.po_utils.basic +import numpy as np +import torchvision.ops as ops +from dust3r.utils.po_utils.basic import print_ + +def matmul2(mat1, mat2): + return torch.matmul(mat1, mat2) + +def matmul3(mat1, mat2, mat3): + return torch.matmul(mat1, torch.matmul(mat2, mat3)) + +def eye_3x3(B, device='cuda'): + rt = torch.eye(3, device=torch.device(device)).view(1,3,3).repeat([B, 1, 1]) + return rt + +def eye_4x4(B, device='cuda'): + rt = torch.eye(4, device=torch.device(device)).view(1,4,4).repeat([B, 1, 1]) + return rt + +def safe_inverse(a): #parallel version + B, _, _ = list(a.shape) + inv = a.clone() + r_transpose = a[:, :3, :3].transpose(1,2) #inverse of rotation matrix + + inv[:, :3, :3] = r_transpose + inv[:, :3, 3:4] = -torch.matmul(r_transpose, a[:, :3, 3:4]) + + return inv + +def safe_inverse_single(a): + r, t = split_rt_single(a) + t = t.view(3,1) + r_transpose = r.t() + inv = torch.cat([r_transpose, -torch.matmul(r_transpose, t)], 1) + bottom_row = a[3:4, :] # this is [0, 0, 0, 1] + # bottom_row = torch.tensor([0.,0.,0.,1.]).view(1,4) + inv = torch.cat([inv, bottom_row], 0) + return inv + +def split_intrinsics(K): + # K is B x 3 x 3 or B x 4 x 4 + fx = K[:,0,0] + fy = K[:,1,1] + x0 = K[:,0,2] + y0 = K[:,1,2] + return fx, fy, x0, y0 + +def apply_pix_T_cam(pix_T_cam, xyz): + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + + # xyz is shaped B x H*W x 3 + # returns xy, shaped B x H*W x 2 + + B, N, C = list(xyz.shape) + assert(C==3) + + x, y, z = torch.unbind(xyz, axis=-1) + + fx = torch.reshape(fx, [B, 1]) + fy = torch.reshape(fy, [B, 1]) + x0 = torch.reshape(x0, [B, 1]) + y0 = torch.reshape(y0, [B, 1]) + + EPS = 1e-4 + z = torch.clamp(z, min=EPS) + x = (x*fx)/(z)+x0 + y = (y*fy)/(z)+y0 + xy = torch.stack([x, y], axis=-1) + return xy + +def apply_pix_T_cam_py(pix_T_cam, xyz): + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + + # xyz is shaped B x H*W x 3 + # returns xy, shaped B x H*W x 2 + + B, N, C = list(xyz.shape) + assert(C==3) + + x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2] + + fx = np.reshape(fx, [B, 1]) + fy = np.reshape(fy, [B, 1]) + x0 = np.reshape(x0, [B, 1]) + y0 = np.reshape(y0, [B, 1]) + + EPS = 1e-4 + z = np.clip(z, EPS, None) + x = (x*fx)/(z)+x0 + y = (y*fy)/(z)+y0 + xy = np.stack([x, y], axis=-1) + return xy + +def get_camM_T_camXs(origin_T_camXs, ind=0): + B, S = list(origin_T_camXs.shape)[0:2] + camM_T_camXs = torch.zeros_like(origin_T_camXs) + for b in list(range(B)): + camM_T_origin = safe_inverse_single(origin_T_camXs[b,ind]) + for s in list(range(S)): + camM_T_camXs[b,s] = torch.matmul(camM_T_origin, origin_T_camXs[b,s]) + return camM_T_camXs + +def apply_4x4(RT, xyz): + B, N, _ = list(xyz.shape) + ones = torch.ones_like(xyz[:,:,0:1]) + xyz1 = torch.cat([xyz, ones], 2) + xyz1_t = torch.transpose(xyz1, 1, 2) + # this is B x 4 x N + xyz2_t = torch.matmul(RT, xyz1_t) + xyz2 = torch.transpose(xyz2_t, 1, 2) + xyz2 = xyz2[:,:,:3] + return xyz2 + +def apply_4x4_py(RT, xyz): + # print('RT', RT.shape) + B, N, _ = list(xyz.shape) + ones = np.ones_like(xyz[:,:,0:1]) + xyz1 = np.concatenate([xyz, ones], 2) + # print('xyz1', xyz1.shape) + xyz1_t = xyz1.transpose(0,2,1) + # print('xyz1_t', xyz1_t.shape) + # this is B x 4 x N + xyz2_t = np.matmul(RT, xyz1_t) + # print('xyz2_t', xyz2_t.shape) + xyz2 = xyz2_t.transpose(0,2,1) + # print('xyz2', xyz2.shape) + xyz2 = xyz2[:,:,:3] + return xyz2 + +def apply_3x3(RT, xy): + B, N, _ = list(xy.shape) + ones = torch.ones_like(xy[:,:,0:1]) + xy1 = torch.cat([xy, ones], 2) + xy1_t = torch.transpose(xy1, 1, 2) + # this is B x 4 x N + xy2_t = torch.matmul(RT, xy1_t) + xy2 = torch.transpose(xy2_t, 1, 2) + xy2 = xy2[:,:,:2] + return xy2 + +def generate_polygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts): + ''' + Start with the center of the polygon at ctr_x, ctr_y, + Then creates the polygon by sampling points on a circle around the center. + Random noise is added by varying the angular spacing between sequential points, + and by varying the radial distance of each point from the centre. + + Params: + ctr_x, ctr_y - coordinates of the "centre" of the polygon + avg_r - in px, the average radius of this polygon, this roughly controls how large the polygon is, really only useful for order of magnitude. + irregularity - [0,1] indicating how much variance there is in the angular spacing of vertices. [0,1] will map to [0, 2pi/numberOfVerts] + spikiness - [0,1] indicating how much variance there is in each vertex from the circle of radius avg_r. [0,1] will map to [0, avg_r] +pp num_verts + + Returns: + np.array [num_verts, 2] - CCW order. + ''' + # spikiness + spikiness = np.clip(spikiness, 0, 1) * avg_r + + # generate n angle steps + irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts + lower = (2*np.pi / num_verts) - irregularity + upper = (2*np.pi / num_verts) + irregularity + + # angle steps + angle_steps = np.random.uniform(lower, upper, num_verts) + sc = (2 * np.pi) / angle_steps.sum() + angle_steps *= sc + + # get all radii + angle = np.random.uniform(0, 2*np.pi) + radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r) + + # compute all points + points = [] + for i in range(num_verts): + x = ctr_x + radii[i] * np.cos(angle) + y = ctr_y + radii[i] * np.sin(angle) + points.append([x, y]) + angle += angle_steps[i] + + return np.array(points).astype(int) + + +def get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05, sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05, shy_max=0.05): + ''' + Params: + rot_min: rotation amount min + rot_max: rotation amount max + + tx_min: translation x min + tx_max: translation x max + + ty_min: translation y min + ty_max: translation y max + + sx_min: scaling x min + sx_max: scaling x max + + sy_min: scaling y min + sy_max: scaling y max + + shx_min: shear x min + shx_max: shear x max + + shy_min: shear y min + shy_max: shear y max + + Returns: + transformation matrix: (B, 3, 3) + ''' + # rotation + if rot_max - rot_min != 0: + rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B) + rot_amount = np.pi/180.0*rot_amount + else: + rot_amount = rot_min + rotation = np.zeros((B, 3, 3)) # B, 3, 3 + rotation[:, 2, 2] = 1 + rotation[:, 0, 0] = np.cos(rot_amount) + rotation[:, 0, 1] = -np.sin(rot_amount) + rotation[:, 1, 0] = np.sin(rot_amount) + rotation[:, 1, 1] = np.cos(rot_amount) + + # translation + translation = np.zeros((B, 3, 3)) # B, 3, 3 + translation[:, [0,1,2], [0,1,2]] = 1 + if (tx_max - tx_min) > 0: + trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B) + translation[:, 0, 2] = trans_x + # else: + # translation[:, 0, 2] = tx_max + if ty_max - ty_min != 0: + trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B) + translation[:, 1, 2] = trans_y + # else: + # translation[:, 1, 2] = ty_max + + # scaling + scaling = np.zeros((B, 3, 3)) # B, 3, 3 + scaling[:, [0,1,2], [0,1,2]] = 1 + if (sx_max - sx_min) > 0: + scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B) + scaling[:, 0, 0] = scale_x + # else: + # scaling[:, 0, 0] = sx_max + if (sy_max - sy_min) > 0: + scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B) + scaling[:, 1, 1] = scale_y + # else: + # scaling[:, 1, 1] = sy_max + + # shear + shear = np.zeros((B, 3, 3)) # B, 3, 3 + shear[:, [0,1,2], [0,1,2]] = 1 + if (shx_max - shx_min) > 0: + shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B) + shear[:, 0, 1] = shear_x + # else: + # shear[:, 0, 1] = shx_max + if (shy_max - shy_min) > 0: + shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B) + shear[:, 1, 0] = shear_y + # else: + # shear[:, 1, 0] = shy_max + + # compose all those + rt = np.einsum("ijk,ikl->ijl", rotation, translation) + ss = np.einsum("ijk,ikl->ijl", scaling, shear) + trans = np.einsum("ijk,ikl->ijl", rt, ss) + + return trans + +def get_centroid_from_box2d(box2d): + ymin = box2d[:,0] + xmin = box2d[:,1] + ymax = box2d[:,2] + xmax = box2d[:,3] + x = (xmin+xmax)/2.0 + y = (ymin+ymax)/2.0 + return y, x + +def normalize_boxlist2d(boxlist2d, H, W): + boxlist2d = boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin / float(H) + ymax = ymax / float(H) + xmin = xmin / float(W) + xmax = xmax / float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_boxlist2d(boxlist2d, H, W): + boxlist2d = boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin * float(H) + ymax = ymax * float(H) + xmin = xmin * float(W) + xmax = xmax * float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_box2d(box2d, H, W): + return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def normalize_box2d(box2d, H, W): + return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def get_size_from_box2d(box2d): + ymin = box2d[:,0] + xmin = box2d[:,1] + ymax = box2d[:,2] + xmax = box2d[:,3] + height = ymax-ymin + width = xmax-xmin + return height, width + +def crop_and_resize(im, boxlist, PH, PW, boxlist_is_normalized=False): + B, C, H, W = im.shape + B2, N, D = boxlist.shape + assert(B==B2) + assert(D==4) + # PH, PW is the size to resize to + + # output is B,N,C,PH,PW + + # pt wants xy xy, unnormalized + if boxlist_is_normalized: + boxlist_unnorm = unnormalize_boxlist2d(boxlist, H, W) + else: + boxlist_unnorm = boxlist + + ymin, xmin, ymax, xmax = boxlist_unnorm.unbind(2) + # boxlist_pt = torch.stack([boxlist_unnorm[:,1], boxlist_unnorm[:,0], boxlist_unnorm[:,3], boxlist_unnorm[:,2]], dim=1) + boxlist_pt = torch.stack([xmin, ymin, xmax, ymax], dim=2) + # we want a B-len list of K x 4 arrays + + # print('im', im.shape) + # print('boxlist', boxlist.shape) + # print('boxlist_pt', boxlist_pt.shape) + + # boxlist_pt = list(boxlist_pt.unbind(0)) + + crops = [] + for b in range(B): + crops_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW)) + crops.append(crops_b) + # # crops = im + + # print('crops', crops.shape) + # crops = crops.reshape(B,N,C,PH,PW) + + + # crops = [] + # for b in range(B): + # crop_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW)) + # print('crop_b', crop_b.shape) + # crops.append(crop_b) + crops = torch.stack(crops, dim=0) + + # print('crops', crops.shape) + # boxlist_list = boxlist_pt.unbind(0) + # print('rgb_crop', rgb_crop.shape) + + return crops + + +# def get_boxlist_from_centroid_and_size(cy, cx, h, w, clip=True): +# # cy,cx are both B,N +# ymin = cy - h/2 +# ymax = cy + h/2 +# xmin = cx - w/2 +# xmax = cx + w/2 + +# box = torch.stack([ymin, xmin, ymax, xmax], dim=-1) +# if clip: +# box = torch.clamp(box, 0, 1) +# return box + + +def get_boxlist_from_centroid_and_size(cy, cx, h, w):#, clip=False): + # cy,cx are the same shape + ymin = cy - h/2 + ymax = cy + h/2 + xmin = cx - w/2 + xmax = cx + w/2 + + # if clip: + # ymin = torch.clamp(ymin, 0, H-1) + # ymax = torch.clamp(ymax, 0, H-1) + # xmin = torch.clamp(xmin, 0, W-1) + # xmax = torch.clamp(xmax, 0, W-1) + + box = torch.stack([ymin, xmin, ymax, xmax], dim=-1) + return box + + +def get_box2d_from_mask(mask, normalize=False): + # mask is B, 1, H, W + + B, C, H, W = mask.shape + assert(C==1) + xy = utils.basic.gridcloud2d(B, H, W, norm=False, device=mask.device) # B, H*W, 2 + + box = torch.zeros((B, 4), dtype=torch.float32, device=mask.device) + for b in range(B): + xy_b = xy[b] # H*W, 2 + mask_b = mask[b].reshape(H*W) + xy_ = xy_b[mask_b > 0] + x_ = xy_[:,0] + y_ = xy_[:,1] + ymin = torch.min(y_) + ymax = torch.max(y_) + xmin = torch.min(x_) + xmax = torch.max(x_) + box[b] = torch.stack([ymin, xmin, ymax, xmax], dim=0) + if normalize: + box = normalize_boxlist2d(box.unsqueeze(1), H, W).squeeze(1) + return box + +def convert_box2d_to_intrinsics(box2d, pix_T_cam, H, W, use_image_aspect_ratio=True, mult_padding=1.0): + # box2d is B x 4, with ymin, xmin, ymax, xmax in normalized coords + # ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1) + # H, W is the original size of the image + # mult_padding is relative to object size in pixels + + # i assume we're rendering an image the same size as the original (H, W) + + if not mult_padding==1.0: + y, x = get_centroid_from_box2d(box2d) + h, w = get_size_from_box2d(box2d) + box2d = get_box2d_from_centroid_and_size( + y, x, h*mult_padding, w*mult_padding, clip=False) + + if use_image_aspect_ratio: + h, w = get_size_from_box2d(box2d) + y, x = get_centroid_from_box2d(box2d) + + # note h,w are relative right now + # we need to undo this, to see the real ratio + + h = h*float(H) + w = w*float(W) + box_ratio = h/w + im_ratio = H/float(W) + + # print('box_ratio:', box_ratio) + # print('im_ratio:', im_ratio) + + if box_ratio >= im_ratio: + w = h/im_ratio + # print('setting w:', h/im_ratio) + else: + h = w*im_ratio + # print('setting h:', w*im_ratio) + + box2d = get_box2d_from_centroid_and_size( + y, x, h/float(H), w/float(W), clip=False) + + assert(h > 1e-4) + assert(w > 1e-4) + + ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1) + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + + # the topleft of the new image will now have a different offset from the center of projection + + new_x0 = x0 - xmin*W + new_y0 = y0 - ymin*H + + pix_T_cam = pack_intrinsics(fx, fy, new_x0, new_y0) + # this alone will give me an image in original resolution, + # with its topleft at the box corner + + box_h, box_w = get_size_from_box2d(box2d) + # these are normalized, and shaped B. (e.g., [0.4], [0.3]) + + # we are going to scale the image by the inverse of this, + # since we are zooming into this area + + sy = 1./box_h + sx = 1./box_w + + pix_T_cam = scale_intrinsics(pix_T_cam, sx, sy) + return pix_T_cam, box2d + +def pixels2camera(x,y,z,fx,fy,x0,y0): + # x and y are locations in pixel coordinates, z is a depth in meters + # they can be images or pointclouds + # fx, fy, x0, y0 are camera intrinsics + # returns xyz, sized B x N x 3 + + B = x.shape[0] + + fx = torch.reshape(fx, [B,1]) + fy = torch.reshape(fy, [B,1]) + x0 = torch.reshape(x0, [B,1]) + y0 = torch.reshape(y0, [B,1]) + + x = torch.reshape(x, [B,-1]) + y = torch.reshape(y, [B,-1]) + z = torch.reshape(z, [B,-1]) + + # unproject + x = (z/fx)*(x-x0) + y = (z/fy)*(y-y0) + + xyz = torch.stack([x,y,z], dim=2) + # B x N x 3 + return xyz + +def camera2pixels(xyz, pix_T_cam): + # xyz is shaped B x H*W x 3 + # returns xy, shaped B x H*W x 2 + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + x, y, z = torch.unbind(xyz, dim=-1) + B = list(z.shape)[0] + + fx = torch.reshape(fx, [B,1]) + fy = torch.reshape(fy, [B,1]) + x0 = torch.reshape(x0, [B,1]) + y0 = torch.reshape(y0, [B,1]) + x = torch.reshape(x, [B,-1]) + y = torch.reshape(y, [B,-1]) + z = torch.reshape(z, [B,-1]) + + EPS = 1e-4 + z = torch.clamp(z, min=EPS) + x = (x*fx)/z + x0 + y = (y*fy)/z + y0 + xy = torch.stack([x, y], dim=-1) + return xy + +def depth2pointcloud(z, pix_T_cam): + B, C, H, W = list(z.shape) + device = z.device + y, x = utils.basic.meshgrid2d(B, H, W, device=device) + z = torch.reshape(z, [B, H, W]) + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + xyz = pixels2camera(x, y, z, fx, fy, x0, y0) + return xyz \ No newline at end of file diff --git a/monst3r/dust3r/utils/po_utils/improc.py b/monst3r/dust3r/utils/po_utils/improc.py new file mode 100644 index 0000000..ca52800 --- /dev/null +++ b/monst3r/dust3r/utils/po_utils/improc.py @@ -0,0 +1,1526 @@ +import torch +import numpy as np +import dust3r.utils.po_utils.basic +from sklearn.decomposition import PCA +from matplotlib import cm +import matplotlib.pyplot as plt +import cv2 +import torch.nn.functional as F +import torchvision +EPS = 1e-6 + +from skimage.color import ( + rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb, + rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb) + +def _convert(input_, type_): + return { + 'float': input_.float(), + 'double': input_.double(), + }.get(type_, input_) + +def _generic_transform_sk_3d(transform, in_type='', out_type=''): + def apply_transform_individual(input_): + device = input_.device + input_ = input_.cpu() + input_ = _convert(input_, in_type) + + input_ = input_.permute(1, 2, 0).detach().numpy() + transformed = transform(input_) + output = torch.from_numpy(transformed).float().permute(2, 0, 1) + output = _convert(output, out_type) + return output.to(device) + + def apply_transform(input_): + to_stack = [] + for image in input_: + to_stack.append(apply_transform_individual(image)) + return torch.stack(to_stack) + return apply_transform + +hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb) + +def preprocess_color_tf(x): + import tensorflow as tf + return tf.cast(x,tf.float32) * 1./255 - 0.5 + +def preprocess_color(x): + if isinstance(x, np.ndarray): + return x.astype(np.float32) * 1./255 - 0.5 + else: + return x.float() * 1./255 - 0.5 + +def pca_embed(emb, keep, valid=None): + ## emb -- [S,H/2,W/2,C] + ## keep is the number of principal components to keep + ## Helper function for reduce_emb. + emb = emb + EPS + #emb is B x C x H x W + emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C + + if valid: + valid = valid.cpu().detach().numpy().reshape((H*W)) + + emb_reduced = list() + + B, H, W, C = np.shape(emb) + for img in emb: + if np.isnan(img).any(): + emb_reduced.append(np.zeros([H, W, keep])) + continue + + pixels_kd = np.reshape(img, (H*W, C)) + + if valid: + pixels_kd_pca = pixels_kd[valid] + else: + pixels_kd_pca = pixels_kd + + P = PCA(keep) + P.fit(pixels_kd_pca) + + if valid: + pixels3d = P.transform(pixels_kd)*valid + else: + pixels3d = P.transform(pixels_kd) + + out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32) + if np.isnan(out_img).any(): + emb_reduced.append(np.zeros([H, W, keep])) + continue + + emb_reduced.append(out_img) + + emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32) + + return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2) + +def pca_embed_together(emb, keep): + ## emb -- [S,H/2,W/2,C] + ## keep is the number of principal components to keep + ## Helper function for reduce_emb. + emb = emb + EPS + #emb is B x C x H x W + emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C + + B, H, W, C = np.shape(emb) + if np.isnan(emb).any(): + return torch.zeros(B, keep, H, W) + + pixelskd = np.reshape(emb, (B*H*W, C)) + P = PCA(keep) + P.fit(pixelskd) + pixels3d = P.transform(pixelskd) + out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32) + + if np.isnan(out_img).any(): + return torch.zeros(B, keep, H, W) + + return torch.from_numpy(out_img).permute(0, 3, 1, 2) + +def reduce_emb(emb, valid=None, inbound=None, together=False): + ## emb -- [S,C,H/2,W/2], inbound -- [S,1,H/2,W/2] + ## Reduce number of chans to 3 with PCA. For vis. + # S,H,W,C = emb.shape.as_list() + S, C, H, W = list(emb.size()) + keep = 3 + + if together: + reduced_emb = pca_embed_together(emb, keep) + else: + reduced_emb = pca_embed(emb, keep, valid) #not im + + reduced_emb = utils.basic.normalize(reduced_emb) - 0.5 + if inbound is not None: + emb_inbound = emb*inbound + else: + emb_inbound = None + + return reduced_emb, emb_inbound + +def get_feat_pca(feat, valid=None): + B, C, D, W = list(feat.size()) + # feat is B x C x D x W. If 3D input, average it through Height dimension before passing into this function. + + pca, _ = reduce_emb(feat, valid=valid,inbound=None, together=True) + # pca is B x 3 x W x D + return pca + +def gif_and_tile(ims, just_gif=False): + S = len(ims) + # each im is B x H x W x C + # i want a gif in the left, and the tiled frames on the right + # for the gif tool, this means making a B x S x H x W tensor + # where the leftmost part is sequential and the rest is tiled + gif = torch.stack(ims, dim=1) + if just_gif: + return gif + til = torch.cat(ims, dim=2) + til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1) + im = torch.cat([gif, til], dim=3) + return im + +def back2color(i, blacken_zeros=False): + if blacken_zeros: + const = torch.tensor([-0.5]) + i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i) + return back2color(i) + else: + return ((i+0.5)*255).type(torch.ByteTensor) + +def convert_occ_to_height(occ, reduce_axis=3): + B, C, D, H, W = list(occ.shape) + assert(C==1) + # note that height increases DOWNWARD in the tensor + # (like pixel/camera coordinates) + + G = list(occ.shape)[reduce_axis] + values = torch.linspace(float(G), 1.0, steps=G, dtype=torch.float32, device=occ.device) + if reduce_axis==2: + # fro view + values = values.view(1, 1, G, 1, 1) + elif reduce_axis==3: + # top view + values = values.view(1, 1, 1, G, 1) + elif reduce_axis==4: + # lateral view + values = values.view(1, 1, 1, 1, G) + else: + assert(False) # you have to reduce one of the spatial dims (2-4) + values = torch.max(occ*values, dim=reduce_axis)[0]/float(G) + # values = values.view([B, C, D, W]) + return values + +def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False): + # xy is B x N x 2, containing float x and y coordinates of N things + # grid_xs and grid_ys are B x N x Y x X + + B, N, Y, X = list(grid_xs.shape) + + mu_x = xy[:,:,0].clone() + mu_y = xy[:,:,1].clone() + + x_valid = (mu_x>-0.5) & (mu_x-0.5) & (mu_y 0.5).float() + return prior + +def seq2color(im, norm=True, colormap='coolwarm'): + B, S, H, W = list(im.shape) + # S is sequential + + # prep a mask of the valid pixels, so we can blacken the invalids later + mask = torch.max(im, dim=1, keepdim=True)[0] + + # turn the S dim into an explicit sequence + coeffs = np.linspace(1.0, float(S), S).astype(np.float32)/float(S) + + # # increase the spacing from the center + # coeffs[:int(S/2)] -= 2.0 + # coeffs[int(S/2)+1:] += 2.0 + + coeffs = torch.from_numpy(coeffs).float().cuda() + coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W) + # scale each channel by the right coeff + im = im * coeffs + # now im is in [1/S, 1], except for the invalid parts which are 0 + # keep the highest valid coeff at each pixel + im = torch.max(im, dim=1, keepdim=True)[0] + + out = [] + for b in range(B): + im_ = im[b] + # move channels out to last dim_ + im_ = im_.detach().cpu().numpy() + im_ = np.squeeze(im_) + # im_ is H x W + if colormap=='coolwarm': + im_ = cm.coolwarm(im_)[:, :, :3] + elif colormap=='PiYG': + im_ = cm.PiYG(im_)[:, :, :3] + elif colormap=='winter': + im_ = cm.winter(im_)[:, :, :3] + elif colormap=='spring': + im_ = cm.spring(im_)[:, :, :3] + elif colormap=='onediff': + im_ = np.reshape(im_, (-1)) + im0_ = cm.spring(im_)[:, :3] + im1_ = cm.winter(im_)[:, :3] + im1_[im_==1/float(S)] = im0_[im_==1/float(S)] + im_ = np.reshape(im1_, (H, W, 3)) + else: + assert(False) # invalid colormap + # move channels into dim 0 + im_ = np.transpose(im_, [2, 0, 1]) + im_ = torch.from_numpy(im_).float().cuda() + out.append(im_) + out = torch.stack(out, dim=0) + + # blacken the invalid pixels, instead of using the 0-color + out = out*mask + # out = out*255.0 + + # put it in [-0.5, 0.5] + out = out - 0.5 + + return out + +def colorize(d): + # this is actually just grayscale right now + + if d.ndim==2: + d = d.unsqueeze(dim=0) + else: + assert(d.ndim==3) + + # color_map = cm.get_cmap('plasma') + color_map = cm.get_cmap('inferno') + # S1, D = traj.shape + + # print('d1', d.shape) + C,H,W = d.shape + assert(C==1) + d = d.reshape(-1) + d = d.detach().cpu().numpy() + # print('d2', d.shape) + color = np.array(color_map(d)) * 255 # rgba + # print('color1', color.shape) + color = np.reshape(color[:,:3], [H*W, 3]) + # print('color2', color.shape) + color = torch.from_numpy(color).permute(1,0).reshape(3,H,W) + # # gather + # cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray') + # if cmap=='RdBu' or cmap=='RdYlGn': + # colors = cm(np.arange(256))[:, :3] + # else: + # colors = cm.colors + # colors = np.array(colors).astype(np.float32) + # colors = np.reshape(colors, [-1, 3]) + # colors = tf.constant(colors, dtype=tf.float32) + + # value = tf.gather(colors, indices) + # colorize(value, normalize=True, vmin=None, vmax=None, cmap=None, vals=255) + + # copy to the three chans + # d = d.repeat(3, 1, 1) + return color + + +def oned2inferno(d, norm=True, do_colorize=False): + # convert a 1chan input to a 3chan image output + + # if it's just B x H x W, add a C dim + if d.ndim==3: + d = d.unsqueeze(dim=1) + # d should be B x C x H x W, where C=1 + B, C, H, W = list(d.shape) + assert(C==1) + + if norm: + d = utils.basic.normalize(d) + + if do_colorize: + rgb = torch.zeros(B, 3, H, W) + for b in list(range(B)): + rgb[b] = colorize(d[b]) + else: + rgb = d.repeat(1, 3, 1, 1)*255.0 + # rgb = (255.0*rgb).type(torch.ByteTensor) + rgb = rgb.type(torch.ByteTensor) + + # rgb = tf.cast(255.0*rgb, tf.uint8) + # rgb = tf.reshape(rgb, [-1, hyp.H, hyp.W, 3]) + # rgb = tf.expand_dims(rgb, axis=0) + return rgb + +def oned2gray(d, norm=True): + # convert a 1chan input to a 3chan image output + + # if it's just B x H x W, add a C dim + if d.ndim==3: + d = d.unsqueeze(dim=1) + # d should be B x C x H x W, where C=1 + B, C, H, W = list(d.shape) + assert(C==1) + + if norm: + d = utils.basic.normalize(d) + + rgb = d.repeat(1,3,1,1) + rgb = (255.0*rgb).type(torch.ByteTensor) + return rgb + + +def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20): + + rgb = vis.detach().cpu().numpy()[0] + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + color = (255, 255, 255) + # print('putting frame id', frame_id) + + frame_str = utils.basic.strnum(frame_id) + + text_color_bg = (0,0,0) + font = cv2.FONT_HERSHEY_SIMPLEX + text_size, _ = cv2.getTextSize(frame_str, font, scale, 1) + text_w, text_h = text_size + cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1) + + cv2.putText( + rgb, + frame_str, + (left, top), # from left, from top + font, + scale, # font scale (float) + color, + 1) # font thickness (int) + rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) + vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + return vis + +COLORMAP_FILE = "./utils/bremm.png" +class ColorMap2d: + def __init__(self, filename=None): + self._colormap_file = filename or COLORMAP_FILE + self._img = plt.imread(self._colormap_file) + + self._height = self._img.shape[0] + self._width = self._img.shape[1] + + def __call__(self, X): + assert len(X.shape) == 2 + output = np.zeros((X.shape[0], 3)) + for i in range(X.shape[0]): + x, y = X[i, :] + xp = int((self._width-1) * x) + yp = int((self._height-1) * y) + xp = np.clip(xp, 0, self._width-1) + yp = np.clip(yp, 0, self._height-1) + output[i, :] = self._img[yp, xp] + return output + +def get_n_colors(N, sequential=False): + label_colors = [] + for ii in range(N): + if sequential: + rgb = cm.winter(ii/(N-1)) + rgb = (np.array(rgb) * 255).astype(np.uint8)[:3] + else: + rgb = np.zeros(3) + while np.sum(rgb) < 128: # ensure min brightness + rgb = np.random.randint(0,256,3) + label_colors.append(rgb) + return label_colors + +class Summ_writer(object): + def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False): + self.writer = writer + self.global_step = global_step + self.log_freq = log_freq + self.fps = fps + self.just_gif = just_gif + self.maxwidth = 10000 + self.save_this = (self.global_step % self.log_freq == 0) + self.scalar_freq = max(scalar_freq,1) + + + def summ_gif(self, name, tensor, blacken_zeros=False): + # tensor should be in B x S x C x H x W + + assert tensor.dtype in {torch.uint8,torch.float32} + shape = list(tensor.shape) + + if tensor.dtype == torch.float32: + tensor = back2color(tensor, blacken_zeros=blacken_zeros) + + video_to_write = tensor[0:1] + + S = video_to_write.shape[1] + if S==1: + # video_to_write is 1 x 1 x C x H x W + self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step) + else: + self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step) + + return video_to_write + + def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1): + B, C, H, W = list(rgb.shape) + assert(C==3) + B2, N, D = list(boxlist.shape) + assert(B2==B) + assert(D==4) # ymin, xmin, ymax, xmax + + rgb = back2color(rgb) + if scores is None: + scores = torch.ones(B2, N).float() + if tids is None: + tids = torch.arange(N).reshape(1,N).repeat(B2,N).long() + # tids = torch.zeros(B2, N).long() + out = self.draw_boxlist2d_on_image_py( + rgb[0].cpu().detach().numpy(), + boxlist[0].cpu().detach().numpy(), + scores[0].cpu().detach().numpy(), + tids[0].cpu().detach().numpy(), + linewidth=linewidth) + out = torch.from_numpy(out).type(torch.ByteTensor).permute(2, 0, 1) + out = torch.unsqueeze(out, dim=0) + out = preprocess_color(out) + out = torch.reshape(out, [1, C, H, W]) + return out + + def draw_boxlist2d_on_image_py(self, rgb, boxlist, scores, tids, linewidth=1): + # all inputs are numpy tensors + # rgb is H x W x 3 + # boxlist is N x 4 + # scores is N + # tids is N + + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + # rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + + rgb = rgb.astype(np.uint8).copy() + + + H, W, C = rgb.shape + assert(C==3) + N, D = boxlist.shape + assert(D==4) + + # color_map = cm.get_cmap('tab20') + # color_map = cm.get_cmap('set1') + color_map = cm.get_cmap('Accent') + color_map = color_map.colors + # print('color_map', color_map) + + # draw + for ind, box in enumerate(boxlist): + # box is 4 + if not np.isclose(scores[ind], 0.0): + # box = utils.geom.scale_box2d(box, H, W) + ymin, xmin, ymax, xmax = box + + # ymin, ymax = ymin*H, ymax*H + # xmin, xmax = xmin*W, xmax*W + + # print 'score = %.2f' % scores[ind] + # color_id = tids[ind] % 20 + color_id = tids[ind] + color = color_map[color_id] + color = np.array(color)*255.0 + color = color.round() + # color = color.astype(np.uint8) + # color = color[::-1] + # print('color', color) + + # print 'tid = %d; score = %.3f' % (tids[ind], scores[ind]) + + # if False: + if scores[ind] < 1.0: # not gt + cv2.putText(rgb, + # '%d (%.2f)' % (tids[ind], scores[ind]), + '%.2f' % (scores[ind]), + (int(xmin), int(ymin)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, # font size + color), + #1) # font weight + + xmin = np.clip(int(xmin), 0, W-1) + xmax = np.clip(int(xmax), 0, W-1) + ymin = np.clip(int(ymin), 0, H-1) + ymax = np.clip(int(ymax), 0, H-1) + + cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_AA) + cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_AA) + cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_AA) + cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_AA) + + # rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) + return rgb + + def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, only_return=False, linewidth=2): + B, C, H, W = list(rgb.shape) + boxlist_vis = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth) + return self.summ_rgb(name, boxlist_vis, frame_id=frame_id, only_return=only_return) + + def summ_rgbs(self, name, ims, frame_ids=None, blacken_zeros=False, only_return=False): + if self.save_this: + + ims = gif_and_tile(ims, just_gif=self.just_gif) + vis = ims + + assert vis.dtype in {torch.uint8,torch.float32} + + if vis.dtype == torch.float32: + vis = back2color(vis, blacken_zeros) + + B, S, C, H, W = list(vis.shape) + + if frame_ids is not None: + assert(len(frame_ids)==S) + for s in range(S): + vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s]) + + if int(W) > self.maxwidth: + vis = vis[:,:,:,:self.maxwidth] + + if only_return: + return vis + else: + return self.summ_gif(name, vis, blacken_zeros) + + def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, only_return=False, halfres=False): + if self.save_this: + assert ims.dtype in {torch.uint8,torch.float32} + + if ims.dtype == torch.float32: + ims = back2color(ims, blacken_zeros) + + #ims is B x C x H x W + vis = ims[0:1] # just the first one + B, C, H, W = list(vis.shape) + + if halfres: + vis = F.interpolate(vis, scale_factor=0.5) + + if frame_id is not None: + vis = draw_frame_id_on_vis(vis, frame_id) + + if int(W) > self.maxwidth: + vis = vis[:,:,:,:self.maxwidth] + + if only_return: + return vis + else: + return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros) + + def flow2color(self, flow, clip=50.0): + """ + :param flow: Optical flow tensor. + :return: RGB image normalized between 0 and 1. + """ + + # flow is B x C x H x W + + B, C, H, W = list(flow.size()) + + flow = flow.clone().detach() + + abs_image = torch.abs(flow) + flow_mean = abs_image.mean(dim=[1,2,3]) + flow_std = abs_image.std(dim=[1,2,3]) + + if clip: + flow = torch.clamp(flow, -clip, clip)/clip + else: + # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2) + flow_max = flow_mean + flow_std*2 + 1e-10 + for b in range(B): + flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1) + + radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W + radius_clipped = torch.clamp(radius, 0.0, 1.0) + + angle = torch.atan2(flow[:, 1:], flow[:, 0:1]) / np.pi #B x 1 x H x W + + hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0) + saturation = torch.ones_like(hue) * 0.75 + value = radius_clipped + hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W + + #flow = tf.image.hsv_to_rgb(hsv) + flow = hsv_to_rgb(hsv) + flow = (flow*255.0).type(torch.ByteTensor) + return flow + + def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None): + # flow is B x C x D x W + if self.save_this: + return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id) + else: + return None + + def summ_oneds(self, name, ims, frame_ids=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False): + if self.save_this: + if bev: + B, C, H, _, W = list(ims[0].shape) + if reduce_max: + ims = [torch.max(im, dim=3)[0] for im in ims] + else: + ims = [torch.mean(im, dim=3) for im in ims] + elif fro: + B, C, _, H, W = list(ims[0].shape) + if reduce_max: + ims = [torch.max(im, dim=2)[0] for im in ims] + else: + ims = [torch.mean(im, dim=2) for im in ims] + + + if len(ims) != 1: # sequence + im = gif_and_tile(ims, just_gif=self.just_gif) + else: + im = torch.stack(ims, dim=1) # single frame + + B, S, C, H, W = list(im.shape) + + if logvis and max_val: + max_val = np.log(max_val) + im = torch.log(torch.clamp(im, 0)+1.0) + im = torch.clamp(im, 0, max_val) + im = im/max_val + norm = False + elif max_val: + im = torch.clamp(im, 0, max_val) + im = im/max_val + norm = False + + if norm: + # normalize before oned2inferno, + # so that the ranges are similar within B across S + im = utils.basic.normalize(im) + + im = im.view(B*S, C, H, W) + vis = oned2inferno(im, norm=norm, do_colorize=do_colorize) + vis = vis.view(B, S, 3, H, W) + + if frame_ids is not None: + assert(len(frame_ids)==S) + for s in range(S): + vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s]) + + if W > self.maxwidth: + vis = vis[...,:self.maxwidth] + + if only_return: + return vis + else: + self.summ_gif(name, vis) + + def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, only_return=False): + if self.save_this: + + if bev: + B, C, H, _, W = list(im.shape) + if max_along_y: + im = torch.max(im, dim=3)[0] + else: + im = torch.mean(im, dim=3) + elif fro: + B, C, _, H, W = list(im.shape) + if max_along_y: + im = torch.max(im, dim=2)[0] + else: + im = torch.mean(im, dim=2) + else: + B, C, H, W = list(im.shape) + + im = im[0:1] # just the first one + assert(C==1) + + if logvis and max_val: + max_val = np.log(max_val) + im = torch.log(im) + im = torch.clamp(im, 0, max_val) + im = im/max_val + norm = False + elif max_val: + im = torch.clamp(im, 0, max_val)/max_val + norm = False + + vis = oned2inferno(im, norm=norm) + if W > self.maxwidth: + vis = vis[...,:self.maxwidth] + return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, only_return=only_return) + + def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None): + if self.save_this: + if valids is not None: + valids = torch.stack(valids, dim=1) + + feats = torch.stack(feats, dim=1) + # feats leads with B x S x C + + if feats.ndim==6: + + # feats is B x S x C x D x H x W + if fro: + reduce_dim = 3 + else: + reduce_dim = 4 + + if valids is None: + feats = torch.mean(feats, dim=reduce_dim) + else: + valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1) + feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim) + + B, S, C, D, W = list(feats.size()) + + if not pca: + # feats leads with B x S x C + feats = torch.mean(torch.abs(feats), dim=2, keepdims=True) + # feats leads with B x S x 1 + feats = torch.unbind(feats, dim=1) + return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids) + + else: + __p = lambda x: utils.basic.pack_seqdim(x, B) + __u = lambda x: utils.basic.unpack_seqdim(x, B) + + feats_ = __p(feats) + + if valids is None: + feats_pca_ = get_feat_pca(feats_) + else: + valids_ = __p(valids) + feats_pca_ = get_feat_pca(feats_, valids) + + feats_pca = __u(feats_pca_) + + return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids) + + def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None): + if self.save_this: + if feat.ndim==5: # B x C x D x H x W + + if bev: + reduce_axis = 3 + elif fro: + reduce_axis = 2 + else: + # default to bev + reduce_axis = 3 + + if valid is None: + feat = torch.mean(feat, dim=reduce_axis) + else: + valid = valid.repeat(1, feat.size()[1], 1, 1, 1) + feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis) + + B, C, D, W = list(feat.shape) + + if not pca: + feat = torch.mean(torch.abs(feat), dim=1, keepdims=True) + # feat is B x 1 x D x W + return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id) + else: + feat_pca = get_feat_pca(feat, valid) + return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id) + + def summ_scalar(self, name, value): + if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()): + value = value.detach().cpu().numpy() + if not np.isnan(value): + if (self.log_freq == 1): + self.writer.add_scalar(name, value, global_step=self.global_step) + elif self.save_this or np.mod(self.global_step, self.scalar_freq)==0: + self.writer.add_scalar(name, value, global_step=self.global_step) + + def summ_seg(self, name, seg, only_return=False, frame_id=None, colormap='tab20', label_colors=None): + if not self.save_this: + return + + B,H,W = seg.shape + + if label_colors is None: + custom_label_colors = False + # label_colors = get_n_colors(int(torch.max(seg).item()), sequential=True) + label_colors = cm.get_cmap(colormap).colors + label_colors = [[int(i*255) for i in l] for l in label_colors] + else: + custom_label_colors = True + # label_colors = matplotlib.cm.get_cmap(colormap).colors + # label_colors = [[int(i*255) for i in l] for l in label_colors] + # print('label_colors', label_colors) + + # label_colors = [ + # (0, 0, 0), # None + # (70, 70, 70), # Buildings + # (190, 153, 153), # Fences + # (72, 0, 90), # Other + # (220, 20, 60), # Pedestrians + # (153, 153, 153), # Poles + # (157, 234, 50), # RoadLines + # (128, 64, 128), # Roads + # (244, 35, 232), # Sidewalks + # (107, 142, 35), # Vegetation + # (0, 0, 255), # Vehicles + # (102, 102, 156), # Walls + # (220, 220, 0) # TrafficSigns + # ] + + r = torch.zeros_like(seg,dtype=torch.uint8) + g = torch.zeros_like(seg,dtype=torch.uint8) + b = torch.zeros_like(seg,dtype=torch.uint8) + + for label in range(0,len(label_colors)): + if (not custom_label_colors):# and (N > 20): + label_ = label % 20 + else: + label_ = label + + idx = (seg == label+1) + r[idx] = label_colors[label_][0] + g[idx] = label_colors[label_][1] + b[idx] = label_colors[label_][2] + + rgb = torch.stack([r,g,b],axis=1) + return self.summ_rgb(name,rgb,only_return=only_return, frame_id=frame_id) + + def summ_pts_on_rgb(self, name, trajs, rgb, valids=None, frame_id=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, C, H, W = rgb.shape + B, S, N, D = trajs.shape + + rgb = rgb[0] # C, H, W + trajs = trajs[0] # S, N, 2 + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + valids = valids.long().detach().cpu().numpy() # S, N + + rgb = rgb.astype(np.uint8).copy() + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + valid = valids[:,i] # S + + color_map = cm.get_cmap(cmap) + color = np.array(color_map(i)[:3]) * 255 # rgb + for s in range(S): + if valid[s]: + cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1) + rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0) + rgb = preprocess_color(rgb) + return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id) + + def summ_pts_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, S, C, H, W = rgbs.shape + B, S2, N, D = trajs.shape + assert(S==S2) + + rgbs = rgbs[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + rgbs_color = [] + for rgb in rgbs: + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgbs_color.append(rgb) # each element 3 x H x W + + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + valids = valids.long().detach().cpu().numpy() # S, N + + rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color] + + for i in range(N): + traj = trajs[:,i] # S,2 + valid = valids[:,i] # S + + color_map = cm.get_cmap(cmap) + color = np.array(color_map(0)[:3]) * 255 # rgb + for s in range(S): + if valid[s]: + cv2.circle(rgbs_color[s], (traj[s,0], traj[s,1]), linewidth, color, -1) + rgbs = [] + for rgb in rgbs_color: + rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + rgbs.append(preprocess_color(rgb)) + + return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) + + + def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=False, cmap='coolwarm', vals=None, linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, S, C, H, W = rgbs.shape + B, S2, N, D = trajs.shape + assert(S==S2) + + rgbs = rgbs[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + if vals is not None: + vals = vals[0] # N + # print('vals', vals.shape) + + rgbs_color = [] + for rgb in rgbs: + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgbs_color.append(rgb) # each element 3 x H x W + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i].long().detach().cpu().numpy() # S, 2 + valid = valids[:,i].long().detach().cpu().numpy() # S + + # print('traj', traj.shape) + # print('valid', valid.shape) + + if vals is not None: + # val = vals[:,i].float().detach().cpu().numpy() # [] + val = vals[i].float().detach().cpu().numpy() # [] + # print('val', val.shape) + else: + val = None + + for t in range(S): + if valid[t]: + # traj_seq = traj[max(t-16,0):t+1] + traj_seq = traj[max(t-8,0):t+1] + val_seq = np.linspace(0,1,len(traj_seq)) + # if t<2: + # val_seq = np.zeros_like(val_seq) + # print('val_seq', val_seq) + # val_seq = 1.0 + # val_seq = np.arange(8)/8.0 + # val_seq = val_seq[-len(traj_seq):] + # rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) + rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) + # input() + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + # vis = visibles[:,i] # S + vis = torch.ones_like(traj[:,0]) # S + valid = valids[:,i] # S + rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=0, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) + + rgbs = [] + for rgb in rgbs_color: + rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + rgbs.append(preprocess_color(rgb)) + + return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) + + + def summ_traj2ds_on_rgbs_py(self, name, trajs, rgbs_color, valids=None, frame_ids=None, only_return=False, show_dots=False, cmap='coolwarm', vals=None, linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + # B, S, C, H, W = rgbs.shape + B, S, N, D = trajs.shape + # assert(S==S2) + + # rgbs = rgbs[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + if vals is not None: + vals = vals[0] # N + # print('vals', vals.shape) + + # rgbs_color = [] + # for rgb in rgbs: + # rgb = back2color(rgb).detach().cpu().numpy() + # rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + # rgbs_color.append(rgb) # each element 3 x H x W + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i].long().detach().cpu().numpy() # S, 2 + valid = valids[:,i].long().detach().cpu().numpy() # S + + # print('traj', traj.shape) + # print('valid', valid.shape) + + if vals is not None: + # val = vals[:,i].float().detach().cpu().numpy() # [] + val = vals[i].float().detach().cpu().numpy() # [] + # print('val', val.shape) + else: + val = None + + for t in range(S): + # if valid[t]: + # traj_seq = traj[max(t-16,0):t+1] + traj_seq = traj[max(t-8,0):t+1] + val_seq = np.linspace(0,1,len(traj_seq)) + # if t<2: + # val_seq = np.zeros_like(val_seq) + # print('val_seq', val_seq) + # val_seq = 1.0 + # val_seq = np.arange(8)/8.0 + # val_seq = val_seq[-len(traj_seq):] + # rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) + rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) + # input() + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + # vis = visibles[:,i] # S + vis = torch.ones_like(traj[:,0]) # S + valid = valids[:,i] # S + rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=0, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) + + rgbs = [] + for rgb in rgbs_color: + rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + rgbs.append(preprocess_color(rgb)) + + return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) + + + def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap=None, linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, S, C, H, W = rgbs.shape + B, S2, N, D = trajs.shape + assert(S==S2) + + rgbs = rgbs[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + visibles = visibles[0] # S, N + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + rgbs_color = [] + for rgb in rgbs: + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgbs_color.append(rgb) # each element 3 x H x W + + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + visibles = visibles.float().detach().cpu().numpy() # S, N + valids = valids.long().detach().cpu().numpy() # S, N + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + vis = visibles[:,i] # S + valid = valids[:,i] # S + rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + vis = visibles[:,i] # S + valid = valids[:,i] # S + if valid[0]: + rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth) + + rgbs = [] + for rgb in rgbs_color: + rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + rgbs.append(preprocess_color(rgb)) + + return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) + + def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=False, show_lines=True, frame_id=None, only_return=False, cmap='coolwarm', linewidth=1): + # trajs is B, S, N, 2 + # rgb is B, C, H, W + B, C, H, W = rgb.shape + B, S, N, D = trajs.shape + + rgb = rgb[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) + else: + valids = valids[0] + + rgb_color = back2color(rgb).detach().cpu().numpy() + rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last + + # using maxdist will dampen the colors for short motions + norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1)) # N + maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy() + maxdist = None + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + valids = valids.long().detach().cpu().numpy() # S, N + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S, 2 + valid = valids[:,i] # S + if valid[0]==1: + traj = traj[valid>0] + rgb_color = self.draw_traj_on_image_py( + rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth) + + rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0) + rgb = preprocess_color(rgb_color) + return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id) + + def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None): + # all inputs are numpy tensors + # rgb is 3 x H x W + # traj is S x 2 + + H, W, C = rgb.shape + assert(C==3) + + rgb = rgb.astype(np.uint8).copy() + + S1, D = traj.shape + assert(D==2) + + color_map = cm.get_cmap(cmap) + S1, D = traj.shape + + for s in range(S1): + if val is not None: + color = np.array(color_map(val[s])[:3]) * 255 # rgb + else: + if maxdist is not None: + val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1) + color = np.array(color_map(val)[:3]) * 255 # rgb + else: + color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb + + if show_lines and s<(S1-1): + cv2.line(rgb, + (int(traj[s,0]), int(traj[s,1])), + (int(traj[s+1,0]), int(traj[s+1,1])), + color, + linewidth, + cv2.LINE_AA) + if show_dots: + cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, np.array(color_map(1)[:3])*255, -1) + + # if maxdist is not None: + # val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1) + # color = np.array(color_map(val)[:3]) * 255 # rgb + # else: + # # draw the endpoint of traj, using the next color (which may be the last color) + # color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb + + # # emphasize endpoint + # cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1) + + return rgb + + + + def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None): + # all inputs are numpy tensors + # rgbs is a list of H,W,3 + # traj is S,2 + H, W, C = rgbs[0].shape + assert(C==3) + + rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] + + S1, D = traj.shape + assert(D==2) + + x = int(np.clip(traj[0,0], 0, W-1)) + y = int(np.clip(traj[0,1], 0, H-1)) + color = rgbs[0][y,x] + color = (int(color[0]),int(color[1]),int(color[2])) + for s in range(S): + # bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb + # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1) + cv2.polylines(rgbs[s], + [traj[:s+1]], + False, + color, + linewidth, + cv2.LINE_AA) + return rgbs + + def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None): + # all inputs are numpy tensors + # rgbs is a list of 3,H,W + # xy is N,2 + H, W, C = rgb.shape + assert(C==3) + + rgb = rgb.astype(np.uint8).copy() + + N, D = xy.shape + assert(D==2) + + + xy = xy.astype(np.float32) + xy[:,0] = np.clip(xy[:,0], 0, W-1) + xy[:,1] = np.clip(xy[:,1], 0, H-1) + xy = xy.astype(np.int32) + + + + if colors is None: + colors = get_n_colors(N) + + for n in range(N): + color = colors[n] + # print('color', color) + # color = (color[0]*255).astype(np.uint8) + color = (int(color[0]),int(color[1]),int(color[2])) + + # x = int(np.clip(xy[0,0], 0, W-1)) + # y = int(np.clip(xy[0,1], 0, H-1)) + # color_ = rgbs[0][y,x] + # color_ = (int(color_[0]),int(color_[1]),int(color_[2])) + # color_ = (int(color_[0]),int(color_[1]),int(color_[2])) + + cv2.circle(rgb, (xy[n,0], xy[n,1]), linewidth, color, 3) + # vis_color = int(np.squeeze(vis[s])*255) + # vis_color = (vis_color,vis_color,vis_color) + # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1) + return rgb + + def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None): + # all inputs are numpy tensors + # rgbs is a list of 3,H,W + # traj is S,2 + H, W, C = rgbs[0].shape + assert(C==3) + + rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] + + S1, D = traj.shape + assert(D==2) + + if cmap is None: + bremm = ColorMap2d() + traj_ = traj[0:1].astype(np.float32) + traj_[:,0] /= float(W) + traj_[:,1] /= float(H) + color = bremm(traj_) + # print('color', color) + color = (color[0]*255).astype(np.uint8) + # color = (int(color[0]),int(color[1]),int(color[2])) + color = (int(color[2]),int(color[1]),int(color[0])) + + for s in range(S1): + if cmap is not None: + color_map = cm.get_cmap(cmap) + # color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb + color = np.array(color_map((s+1)/max(1,float(S-1)))[:3]) * 255 # rgb + # color = color.astype(np.uint8) + # color = (color[0], color[1], color[2]) + # print('color', color) + # import ipdb; ipdb.set_trace() + + cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, color, -1) + # vis_color = int(np.squeeze(vis[s])*255) + # vis_color = (vis_color,vis_color,vis_color) + # cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1) + + return rgbs + + def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, only_return=False, show_circ=False, trajs_g=None, is_g=False): + B, S, N, D = trajs_e.shape + assert(N==1) + assert(D==2) + + rgbs_vis = [] + n = 0 + pad_amount = 100 + trajs_e_py = trajs_e[0].detach().cpu().numpy() + # trajs_e_py = np.clip(trajs_e_py, min=pad_amount/2, max=pad_amoun + trajs_e_py = trajs_e_py + pad_amount + + if trajs_g is not None: + trajs_g_py = trajs_g[0].detach().cpu().numpy() + trajs_g_py = trajs_g_py + pad_amount + + for s in range(S): + rgb = rgbs[0,s].detach().cpu().numpy() + # print('orig rgb', rgb.shape) + rgb = np.transpose(rgb,(1,2,0)) # H, W, 3 + + rgb = np.pad(rgb, ((pad_amount,pad_amount),(pad_amount,pad_amount),(0,0))) + # print('pad rgb', rgb.shape) + H, W, C = rgb.shape + + if trajs_g is not None: + xy_g = trajs_g_py[s,n] + xy_g[0] = np.clip(xy_g[0], pad_amount, W-pad_amount) + xy_g[1] = np.clip(xy_g[1], pad_amount, H-pad_amount) + rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3) + + xy_e = trajs_e_py[s,n] + xy_e[0] = np.clip(xy_e[0], pad_amount, W-pad_amount) + xy_e[1] = np.clip(xy_e[1], pad_amount, H-pad_amount) + + if show_circ: + if is_g: + rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3) + else: + rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,0,255)], linewidth=2, radius=3) + + + xmin = int(xy_e[0])-pad_amount//2 + xmax = int(xy_e[0])+pad_amount//2 + ymin = int(xy_e[1])-pad_amount//2 + ymax = int(xy_e[1])+pad_amount//2 + + rgb_ = rgb[ymin:ymax, xmin:xmax] + + H_, W_ = rgb_.shape[:2] + # if np.any(rgb_.shape==0): + # input() + if H_==0 or W_==0: + import ipdb; ipdb.set_trace() + + rgb_ = rgb_.transpose(2,0,1) + rgb_ = torch.from_numpy(rgb_) + + rgbs_vis.append(rgb_) + + # nrow = int(np.sqrt(S)*(16.0/9)/2.0) + nrow = int(np.sqrt(S)*1.5) + grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0) + # print('grid_img', grid_img.shape) + return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, only_return=only_return) + + def summ_occ(self, name, occ, reduce_axes=[3], bev=False, fro=False, pro=False, frame_id=None, only_return=False): + if self.save_this: + B, C, D, H, W = list(occ.shape) + if bev: + reduce_axes = [3] + elif fro: + reduce_axes = [2] + elif pro: + reduce_axes = [4] + for reduce_axis in reduce_axes: + height = convert_occ_to_height(occ, reduce_axis=reduce_axis) + if reduce_axis == reduce_axes[-1]: + return self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return) + else: + self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return) + +def erode2d(im, times=1, device='cuda'): + weights2d = torch.ones(1, 1, 3, 3, device=device) + for time in range(times): + im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1).clamp(0, 1) + return im + +def dilate2d(im, times=1, device='cuda', mode='square'): + weights2d = torch.ones(1, 1, 3, 3, device=device) + if mode=='cross': + weights2d[:,:,0,0] = 0.0 + weights2d[:,:,0,2] = 0.0 + weights2d[:,:,2,0] = 0.0 + weights2d[:,:,2,2] = 0.0 + for time in range(times): + im = F.conv2d(im, weights2d, padding=1).clamp(0, 1) + return im \ No newline at end of file diff --git a/monst3r/dust3r/utils/po_utils/misc.py b/monst3r/dust3r/utils/po_utils/misc.py new file mode 100644 index 0000000..bedf119 --- /dev/null +++ b/monst3r/dust3r/utils/po_utils/misc.py @@ -0,0 +1,165 @@ +import torch +import numpy as np +import math +from prettytable import PrettyTable + +def count_parameters(model): + table = PrettyTable(["Modules", "Parameters"]) + total_params = 0 + for name, parameter in model.named_parameters(): + if not parameter.requires_grad: + continue + param = parameter.numel() + if param > 100000: + table.add_row([name, param]) + total_params+=param + print(table) + print('total params: %.2f M' % (total_params/1000000.0)) + return total_params + +def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False): + device = xy.device + dtype = xy.dtype + B, S, D = xy.shape + assert(D==2) + x = xy[:,:,0] + y = xy[:,:,1] + assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb' + omega = torch.arange(C // 4, device=device) / (C // 4 - 1) + omega = 1. / (temperature ** omega) + + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + pe = pe.reshape(B,S,C).type(dtype) + if cat_coords: + pe = torch.cat([pe, xy], dim=2) # B,N,C+2 + return pe + +class SimplePool(): + def __init__(self, pool_size, version='pt'): + self.pool_size = pool_size + self.version = version + self.items = [] + + if not (version=='pt' or version=='np'): + print('version = %s; please choose pt or np') + assert(False) # please choose pt or np + + def __len__(self): + return len(self.items) + + def mean(self, min_size=1): + if min_size=='half': + pool_size_thresh = self.pool_size/2 + else: + pool_size_thresh = min_size + + if self.version=='np': + if len(self.items) >= pool_size_thresh: + return np.sum(self.items)/float(len(self.items)) + else: + return np.nan + if self.version=='pt': + if len(self.items) >= pool_size_thresh: + return torch.sum(self.items)/float(len(self.items)) + else: + return torch.from_numpy(np.nan) + + def sample(self, with_replacement=True): + idx = np.random.randint(len(self.items)) + if with_replacement: + return self.items[idx] + else: + return self.items.pop(idx) + + def fetch(self, num=None): + if self.version=='pt': + item_array = torch.stack(self.items) + elif self.version=='np': + item_array = np.stack(self.items) + if num is not None: + # there better be some items + assert(len(self.items) >= num) + + # if there are not that many elements just return however many there are + if len(self.items) < num: + return item_array + else: + idxs = np.random.randint(len(self.items), size=num) + return item_array[idxs] + else: + return item_array + + def is_full(self): + full = len(self.items)==self.pool_size + return full + + def empty(self): + self.items = [] + + def update(self, items): + for item in items: + if len(self.items) < self.pool_size: + # the pool is not full, so let's add this in + self.items.append(item) + else: + # the pool is full + # pop from the front + self.items.pop(0) + # add to the back + self.items.append(item) + return self.items + +def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False): + """ + Input: + xyz: pointcloud data, [B, N, C], where C is probably 3 + npoint: number of samples + Return: + inds: sampled pointcloud index, [B, npoint] + """ + device = xyz.device + B, N, C = xyz.shape + xyz = xyz.float() + inds = torch.zeros(B, npoint, dtype=torch.long).to(device) + distance = torch.ones(B, N).to(device) * 1e10 + if deterministic: + farthest = torch.randint(0, 1, (B,), dtype=torch.long).to(device) + else: + farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) + batch_indices = torch.arange(B, dtype=torch.long).to(device) + for i in range(npoint): + if include_ends: + if i==0: + farthest = 0 + elif i==1: + farthest = N-1 + inds[:, i] = farthest + centroid = xyz[batch_indices, farthest, :].view(B, 1, C) + dist = torch.sum((xyz - centroid) ** 2, -1) + mask = dist < distance + distance[mask] = dist[mask] + farthest = torch.max(distance, -1)[1] + + if npoint > N: + # if we need more samples, make them random + distance += torch.randn_like(distance) + return inds + +def farthest_point_sample_py(xyz, npoint): + N,C = xyz.shape + inds = np.zeros(npoint, dtype=np.int32) + distance = np.ones(N) * 1e10 + farthest = np.random.randint(0, N, dtype=np.int32) + for i in range(npoint): + inds[i] = farthest + centroid = xyz[farthest, :].reshape(1,C) + dist = np.sum((xyz - centroid) ** 2, -1) + mask = dist < distance + distance[mask] = dist[mask] + farthest = np.argmax(distance, -1) + if npoint > N: + # if we need more samples, make them random + distance += np.random.randn(*distance.shape) + return inds \ No newline at end of file diff --git a/monst3r/dust3r/utils/viz_demo.py b/monst3r/dust3r/utils/viz_demo.py new file mode 100644 index 0000000..3102a34 --- /dev/null +++ b/monst3r/dust3r/utils/viz_demo.py @@ -0,0 +1,125 @@ +from mistune import BaseRenderer +from scipy.spatial.transform import Rotation +import numpy as np +import trimesh +from dust3r.utils.device import to_numpy +import torch +import os +import cv2 +from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes +from third_party.raft import load_RAFT +from datasets_preprocess.sintel_get_dynamics import compute_optical_flow +from dust3r.utils.flow_vis import flow_to_image + +def convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05, show_cam=True, + cam_color=None, as_pointcloud=False, + transparent_cams=False, silent=False, save_name=None): + assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals) + pts3d = to_numpy(pts3d) + imgs = to_numpy(imgs) + focals = to_numpy(focals) + cams2world = to_numpy(cams2world) + + scene = trimesh.Scene() + + # full pointcloud + if as_pointcloud: + pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]) + col = np.concatenate([p[m] for p, m in zip(imgs, mask)]) + pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3)) + scene.add_geometry(pct) + else: + meshes = [] + for i in range(len(imgs)): + meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i])) + mesh = trimesh.Trimesh(**cat_meshes(meshes)) + scene.add_geometry(mesh) + + # add each camera + if show_cam: + for i, pose_c2w in enumerate(cams2world): + if isinstance(cam_color, list): + camera_edge_color = cam_color[i] + else: + camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)] + add_scene_cam(scene, pose_c2w, camera_edge_color, + None if transparent_cams else imgs[i], focals[i], + imsize=imgs[i].shape[1::-1], screen_width=cam_size) + + rot = np.eye(4) + rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix() + scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot)) + if save_name is None: save_name='scene' + outfile = os.path.join(outdir, save_name+'.glb') + if not silent: + print('(exporting 3D scene to', outfile, ')') + scene.export(file_obj=outfile) + return outfile + +def get_dynamic_mask_from_pairviewer(scene, flow_net=None, both_directions=False, output_dir='./demo_tmp', motion_mask_thre=0.35): + """ + get the dynamic mask from the pairviewer + """ + if flow_net is None: + # flow_net = load_RAFT(model_path="third_party/RAFT/models/Tartan-C-T-TSKH-spring540x960-M.pth").to('cuda').eval() # sea-raft + flow_net = load_RAFT(model_path="third_party/RAFT/models/raft-things.pth").to('cuda').eval() + + imgs = scene.imgs + img1 = torch.from_numpy(imgs[0]*255).permute(2,0,1)[None] # (B, 3, H, W) + img2 = torch.from_numpy(imgs[1]*255).permute(2,0,1)[None] + with torch.no_grad(): + forward_flow = flow_net(img1.cuda(), img2.cuda(), iters=20, test_mode=True)[1] # (B, 2, H, W) + if both_directions: + backward_flow = flow_net(img2.cuda(), img1.cuda(), iters=20, test_mode=True)[1] + + B, _, H, W = forward_flow.shape + + depth_map1 = scene.get_depthmaps()[0] # (H, W) + depth_map2 = scene.get_depthmaps()[1] + + im_poses = scene.get_im_poses() + cam1 = im_poses[0] # (4, 4) cam2world + cam2 = im_poses[1] + extrinsics1 = torch.linalg.inv(cam1) # (4, 4) world2cam + extrinsics2 = torch.linalg.inv(cam2) + + intrinsics = scene.get_intrinsics() + intrinsics_1 = intrinsics[0] # (3, 3) + intrinsics_2 = intrinsics[1] + + ego_flow_1_2 = compute_optical_flow(depth_map1, depth_map2, extrinsics1, extrinsics2, intrinsics_1, intrinsics_2) # (H*W, 2) + ego_flow_1_2 = ego_flow_1_2.reshape(H, W, 2).transpose(2, 0, 1) # (2, H, W) + + error_map = np.linalg.norm(ego_flow_1_2 - forward_flow[0].cpu().numpy(), axis=0) # (H, W) + + error_map_normalized = (error_map - error_map.min()) / (error_map.max() - error_map.min()) + error_map_normalized_int = (error_map_normalized * 255).astype(np.uint8) + if both_directions: + ego_flow_2_1 = compute_optical_flow(depth_map2, depth_map1, extrinsics2, extrinsics1, intrinsics_2, intrinsics_1) + ego_flow_2_1 = ego_flow_2_1.reshape(H, W, 2).transpose(2, 0, 1) + error_map_2 = np.linalg.norm(ego_flow_2_1 - backward_flow[0].cpu().numpy(), axis=0) + error_map_2_normalized = (error_map_2 - error_map_2.min()) / (error_map_2.max() - error_map_2.min()) + error_map_2_normalized = (error_map_2_normalized * 255).astype(np.uint8) + cv2.imwrite(f'{output_dir}/dynamic_mask_bw.png', cv2.applyColorMap(error_map_2_normalized, cv2.COLORMAP_JET)) + np.save(f'{output_dir}/dynamic_mask_bw.npy', error_map_2) + + backward_flow = backward_flow[0].cpu().numpy().transpose(1, 2, 0) + np.save(f'{output_dir}/backward_flow.npy', backward_flow) + flow_img = flow_to_image(backward_flow) + cv2.imwrite(f'{output_dir}/backward_flow.png', flow_img) + + cv2.imwrite(f'{output_dir}/dynamic_mask.png', cv2.applyColorMap(error_map_normalized_int, cv2.COLORMAP_JET)) + error_map_normalized_bin = (error_map_normalized > motion_mask_thre).astype(np.uint8) + # save the binary mask + cv2.imwrite(f'{output_dir}/dynamic_mask_binary.png', error_map_normalized_bin*255) + # save the original one as npy file + np.save(f'{output_dir}/dynamic_mask.npy', error_map) + + # also save the flow + forward_flow = forward_flow[0].cpu().numpy().transpose(1, 2, 0) + np.save(f'{output_dir}/forward_flow.npy', forward_flow) + # save flow as image + flow_img = flow_to_image(forward_flow) + cv2.imwrite(f'{output_dir}/forward_flow.png', flow_img) + + return error_map \ No newline at end of file diff --git a/monst3r/dust3r/utils/vo_eval.py b/monst3r/dust3r/utils/vo_eval.py new file mode 100644 index 0000000..8f95a8e --- /dev/null +++ b/monst3r/dust3r/utils/vo_eval.py @@ -0,0 +1,334 @@ +import os +import re +from copy import deepcopy +from pathlib import Path + +import evo.main_ape as main_ape +import evo.main_rpe as main_rpe +import matplotlib.pyplot as plt +import numpy as np +from evo.core import sync +from evo.core.metrics import PoseRelation, Unit +from evo.core.trajectory import PosePath3D, PoseTrajectory3D +from evo.tools import file_interface, plot +from scipy.spatial.transform import Rotation + + +def sintel_cam_read(filename): + """Read camera data, return (M,N) tuple. + + M is the intrinsic matrix, N is the extrinsic matrix, so that + + x = M*N*X, + with x being a point in homogeneous image pixel coordinates, X being a + point in homogeneous world coordinates. + """ + TAG_FLOAT = 202021.25 + + f = open(filename, "rb") + check = np.fromfile(f, dtype=np.float32, count=1)[0] + assert ( + check == TAG_FLOAT + ), " cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format( + TAG_FLOAT, check + ) + M = np.fromfile(f, dtype="float64", count=9).reshape((3, 3)) + N = np.fromfile(f, dtype="float64", count=12).reshape((3, 4)) + return M, N + + +def load_replica_traj(gt_file): + traj_w_c = np.loadtxt(gt_file) + assert traj_w_c.shape[1] == 12 or traj_w_c.shape[1] == 16 + poses = [ + np.array( + [ + [r[0], r[1], r[2], r[3]], + [r[4], r[5], r[6], r[7]], + [r[8], r[9], r[10], r[11]], + [0, 0, 0, 1], + ] + ) + for r in traj_w_c + ] + + pose_path = PosePath3D(poses_se3=poses) + timestamps_mat = np.arange(traj_w_c.shape[0]).astype(float) + + traj = PoseTrajectory3D(poses_se3=pose_path.poses_se3, timestamps=timestamps_mat) + xyz = traj.positions_xyz + # shift -1 column -> w in back column + # quat = np.roll(traj.orientations_quat_wxyz, -1, axis=1) + # uncomment this line if the quaternion is in scalar-first format + quat = traj.orientations_quat_wxyz + + traj_tum = np.column_stack((xyz, quat)) + return (traj_tum, timestamps_mat) + + +def load_sintel_traj(gt_file): # './data/sintel/training/camdata_left/alley_2' + # Refer to ParticleSfM + gt_pose_lists = sorted(os.listdir(gt_file)) + gt_pose_lists = [os.path.join(gt_file, x) for x in gt_pose_lists if x.endswith(".cam")] + tstamps = [float(x.split("/")[-1][:-4].split("_")[-1]) for x in gt_pose_lists] + gt_poses = [sintel_cam_read(f)[1] for f in gt_pose_lists] # [1] means get the extrinsic + xyzs, wxyzs = [], [] + tum_gt_poses = [] + for gt_pose in gt_poses: + gt_pose = np.concatenate([gt_pose, np.array([[0, 0, 0, 1]])], 0) + gt_pose_inv = np.linalg.inv(gt_pose) # world2cam -> cam2world + xyz = gt_pose_inv[:3, -1] + xyzs.append(xyz) + R = Rotation.from_matrix(gt_pose_inv[:3, :3]) + xyzw = R.as_quat() # scalar-last for scipy + wxyz = np.array([xyzw[-1], xyzw[0], xyzw[1], xyzw[2]]) + wxyzs.append(wxyz) + tum_gt_pose = np.concatenate([xyz, wxyz], 0) #TODO: check if this is correct + tum_gt_poses.append(tum_gt_pose) + + tum_gt_poses = np.stack(tum_gt_poses, 0) + tum_gt_poses[:, :3] = tum_gt_poses[:, :3] - np.mean( + tum_gt_poses[:, :3], 0, keepdims=True + ) + tt = np.expand_dims(np.stack(tstamps, 0), -1) + return tum_gt_poses, tt + + +def load_traj(gt_traj_file, traj_format="sintel", skip=0, stride=1, num_frames=None): + """Read trajectory format. Return in TUM-RGBD format. + Returns: + traj_tum (N, 7): camera to world poses in (x,y,z,qx,qy,qz,qw) + timestamps_mat (N, 1): timestamps + """ + if traj_format == "replica": + traj_tum, timestamps_mat = load_replica_traj(gt_traj_file) + elif traj_format == "sintel": + traj_tum, timestamps_mat = load_sintel_traj(gt_traj_file) + elif traj_format in ["tum", "tartanair"]: + traj = file_interface.read_tum_trajectory_file(gt_traj_file) + xyz = traj.positions_xyz + quat = traj.orientations_quat_wxyz + timestamps_mat = traj.timestamps + traj_tum = np.column_stack((xyz, quat)) + else: + raise NotImplementedError + + traj_tum = traj_tum[skip::stride] + timestamps_mat = timestamps_mat[skip::stride] + if num_frames is not None: + traj_tum = traj_tum[:num_frames] + timestamps_mat = timestamps_mat[:num_frames] + return traj_tum, timestamps_mat + + +def update_timestamps(gt_file, traj_format, skip=0, stride=1): + """Update timestamps given a""" + if traj_format == "tum": + traj_t_map_file = gt_file.replace("groundtruth.txt", "rgb.txt") + timestamps = load_timestamps(traj_t_map_file, traj_format) + return timestamps[skip::stride] + elif traj_format == "tartanair": + traj_t_map_file = gt_file.replace("gt_pose.txt", "times.txt") + timestamps = load_timestamps(traj_t_map_file, traj_format) + return timestamps[skip::stride] + + +def load_timestamps(time_file, traj_format="replica"): + if traj_format in ["tum", "tartanair"]: + with open(time_file, "r+") as f: + lines = f.readlines() + timestamps_mat = [ + float(x.split(" ")[0]) for x in lines if not x.startswith("#") + ] + return timestamps_mat + + +def make_traj(args) -> PoseTrajectory3D: + if isinstance(args, tuple) or isinstance(args, list): + traj, tstamps = args + return PoseTrajectory3D( + positions_xyz=traj[:, :3], + orientations_quat_wxyz=traj[:, 3:], + timestamps=tstamps, + ) + assert isinstance(args, PoseTrajectory3D), type(args) + return deepcopy(args) + + +def eval_metrics(pred_traj, gt_traj=None, seq="", filename="", sample_stride=1): + + if sample_stride > 1: + pred_traj[0] = pred_traj[0][::sample_stride] + pred_traj[1] = pred_traj[1][::sample_stride] + if gt_traj is not None: + updated_gt_traj = [] + updated_gt_traj.append(gt_traj[0][::sample_stride]) + updated_gt_traj.append(gt_traj[1][::sample_stride]) + gt_traj = updated_gt_traj + + pred_traj = make_traj(pred_traj) + + if gt_traj is not None: + gt_traj = make_traj(gt_traj) + + if pred_traj.timestamps.shape[0] == gt_traj.timestamps.shape[0]: + pred_traj.timestamps = gt_traj.timestamps + else: + print(pred_traj.timestamps.shape[0], gt_traj.timestamps.shape[0]) + + gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj) + + # ATE + traj_ref = gt_traj + traj_est = pred_traj + + ate_result = main_ape.ape( + traj_ref, + traj_est, + est_name="traj", + pose_relation=PoseRelation.translation_part, + align=True, + correct_scale=True, + ) + + ate = ate_result.stats["rmse"] + + # RPE rotation and translation + delta_list = [1] + rpe_rots, rpe_transs = [], [] + for delta in delta_list: + rpe_rots_result = main_rpe.rpe( + traj_ref, + traj_est, + est_name="traj", + pose_relation=PoseRelation.rotation_angle_deg, + align=True, + correct_scale=True, + delta=delta, + delta_unit=Unit.frames, + rel_delta_tol=0.01, + all_pairs=True, + ) + + rot = rpe_rots_result.stats["rmse"] + rpe_rots.append(rot) + + for delta in delta_list: + rpe_transs_result = main_rpe.rpe( + traj_ref, + traj_est, + est_name="traj", + pose_relation=PoseRelation.translation_part, + align=True, + correct_scale=True, + delta=delta, + delta_unit=Unit.frames, + rel_delta_tol=0.01, + all_pairs=True, + ) + + trans = rpe_transs_result.stats["rmse"] + rpe_transs.append(trans) + + rpe_trans, rpe_rot = np.mean(rpe_transs), np.mean(rpe_rots) + with open(filename, "w+") as f: + f.write(f"Seq: {seq} \n\n") + f.write(f"{ate_result}") + f.write(f"{rpe_rots_result}") + f.write(f"{rpe_transs_result}") + + print(f"Save results to {filename}") + return ate, rpe_trans, rpe_rot + + +def best_plotmode(traj): + _, i1, i2 = np.argsort(np.var(traj.positions_xyz, axis=0)) + plot_axes = "xyz"[i2] + "xyz"[i1] + return getattr(plot.PlotMode, plot_axes) + + +def plot_trajectory( + pred_traj, gt_traj=None, title="", filename="", align=True, correct_scale=True +): + pred_traj = make_traj(pred_traj) + + if gt_traj is not None: + gt_traj = make_traj(gt_traj) + if pred_traj.timestamps.shape[0] == gt_traj.timestamps.shape[0]: + pred_traj.timestamps = gt_traj.timestamps + else: + print("WARNING", pred_traj.timestamps.shape[0], gt_traj.timestamps.shape[0]) + + gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj) + + if align: + pred_traj.align(gt_traj, correct_scale=correct_scale) + + plot_collection = plot.PlotCollection("PlotCol") + fig = plt.figure(figsize=(8, 8)) + plot_mode = best_plotmode(gt_traj if (gt_traj is not None) else pred_traj) + ax = plot.prepare_axis(fig, plot_mode) + ax.set_title(title) + if gt_traj is not None: + plot.traj(ax, plot_mode, gt_traj, "--", "gray", "Ground Truth") + plot.traj(ax, plot_mode, pred_traj, "-", "blue", "Predicted") + plot_collection.add_figure("traj_error", fig) + plot_collection.export(filename, confirm_overwrite=False) + plt.close(fig=fig) + print(f"Saved trajectory to {filename.replace('.png','')}_traj_error.png") + + +def save_trajectory_tum_format(traj, filename): + traj = make_traj(traj) + tostr = lambda a: " ".join(map(str, a)) + with Path(filename).open("w") as f: + for i in range(traj.num_poses): + f.write( + f"{traj.timestamps[i]} {tostr(traj.positions_xyz[i])} {tostr(traj.orientations_quat_wxyz[i][[0,1,2,3]])}\n" + ) + print(f"Saved trajectory to {filename}") + + +def extract_metrics(file_path): + with open(file_path, 'r') as file: + content = file.read() + + # Extract metrics using regex + ate_match = re.search(r'APE w.r.t. translation part \(m\).*?rmse\s+([0-9.]+)', content, re.DOTALL) + rpe_trans_match = re.search(r'RPE w.r.t. translation part \(m\).*?rmse\s+([0-9.]+)', content, re.DOTALL) + rpe_rot_match = re.search(r'RPE w.r.t. rotation angle in degrees \(deg\).*?rmse\s+([0-9.]+)', content, re.DOTALL) + + ate = float(ate_match.group(1)) if ate_match else 0.0 + rpe_trans = float(rpe_trans_match.group(1)) if rpe_trans_match else 0.0 + rpe_rot = float(rpe_rot_match.group(1)) if rpe_rot_match else 0.0 + + return ate, rpe_trans, rpe_rot + +def process_directory(directory): + results = [] + for root, _, files in os.walk(directory): + if files is not None: + files = sorted(files) + for file in files: + if file.endswith('_metric.txt'): + file_path = os.path.join(root, file) + seq_name = file.replace('_eval_metric.txt', '') + ate, rpe_trans, rpe_rot = extract_metrics(file_path) + results.append((seq_name, ate, rpe_trans, rpe_rot)) + + return results + +def calculate_averages(results): + total_ate = sum(r[1] for r in results) + total_rpe_trans = sum(r[2] for r in results) + total_rpe_rot = sum(r[3] for r in results) + count = len(results) + + if count == 0: + return 0.0, 0.0, 0.0 + + avg_ate = total_ate / count + avg_rpe_trans = total_rpe_trans / count + avg_rpe_rot = total_rpe_rot / count + + return avg_ate, avg_rpe_trans, avg_rpe_rot diff --git a/monst3r/dust3r/viz.py b/monst3r/dust3r/viz.py new file mode 100644 index 0000000..6fafbca --- /dev/null +++ b/monst3r/dust3r/viz.py @@ -0,0 +1,393 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Visualization utilities using trimesh +# -------------------------------------------------------- +import PIL.Image +import numpy as np +from scipy.spatial.transform import Rotation +import torch +import io +from PIL import Image + +from dust3r.utils.geometry import geotrf, get_med_dist_between_poses, depthmap_to_absolute_camera_coordinates, depthmap_to_pts3d +from dust3r.utils.device import to_numpy +from dust3r.utils.image import rgb, img_to_arr + +try: + import trimesh +except ImportError: + print('/!\\ module trimesh is not installed, cannot visualize results /!\\') + + + +def cat_3d(vecs): + if isinstance(vecs, (np.ndarray, torch.Tensor)): + vecs = [vecs] + return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)]) + + +def show_raw_pointcloud(pts3d, colors, point_size=2): + scene = trimesh.Scene() + + pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors)) + scene.add_geometry(pct) + + scene.show(line_settings={'point_size': point_size}) + + +def pts3d_to_trimesh(img, pts3d, valid=None): + H, W, THREE = img.shape + assert THREE == 3 + assert img.shape == pts3d.shape + + vertices = pts3d.reshape(-1, 3) + + # make squares: each pixel == 2 triangles + idx = np.arange(len(vertices)).reshape(H, W) + idx1 = idx[:-1, :-1].ravel() # top-left corner + idx2 = idx[:-1, +1:].ravel() # right-left corner + idx3 = idx[+1:, :-1].ravel() # bottom-left corner + idx4 = idx[+1:, +1:].ravel() # bottom-right corner + faces = np.concatenate(( + np.c_[idx1, idx2, idx3], + np.c_[idx3, idx2, idx1], # same triangle, but backward (cheap solution to cancel face culling) + np.c_[idx2, idx3, idx4], + np.c_[idx4, idx3, idx2], # same triangle, but backward (cheap solution to cancel face culling) + ), axis=0) + + # prepare triangle colors + face_colors = np.concatenate(( + img[:-1, :-1].reshape(-1, 3), + img[:-1, :-1].reshape(-1, 3), + img[+1:, +1:].reshape(-1, 3), + img[+1:, +1:].reshape(-1, 3) + ), axis=0) + + # remove invalid faces + if valid is not None: + assert valid.shape == (H, W) + valid_idxs = valid.ravel() + valid_faces = valid_idxs[faces].all(axis=-1) + faces = faces[valid_faces] + face_colors = face_colors[valid_faces] + + assert len(faces) == len(face_colors) + return dict(vertices=vertices, face_colors=face_colors, faces=faces) + + +def cat_meshes(meshes): + vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes]) + n_vertices = np.cumsum([0]+[len(v) for v in vertices]) + for i in range(len(faces)): + faces[i][:] += n_vertices[i] + + vertices = np.concatenate(vertices) + colors = np.concatenate(colors) + faces = np.concatenate(faces) + return dict(vertices=vertices, face_colors=colors, faces=faces) + + +def show_duster_pairs(view1, view2, pred1, pred2): + import matplotlib.pyplot as pl + pl.ion() + + for e in range(len(view1['instance'])): + i = view1['idx'][e] + j = view2['idx'][e] + img1 = rgb(view1['img'][e]) + img2 = rgb(view2['img'][e]) + conf1 = pred1['conf'][e].squeeze() + conf2 = pred2['conf'][e].squeeze() + score = conf1.mean()*conf2.mean() + print(f">> Showing pair #{e} {i}-{j} {score=:g}") + pl.clf() + pl.subplot(221).imshow(img1) + pl.subplot(223).imshow(img2) + pl.subplot(222).imshow(conf1, vmin=1, vmax=30) + pl.subplot(224).imshow(conf2, vmin=1, vmax=30) + pts1 = pred1['pts3d'][e] + pts2 = pred2['pts3d_in_other_view'][e] + pl.subplots_adjust(0, 0, 1, 1, 0, 0) + if input('show pointcloud? (y/n) ') == 'y': + show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5) + + +def auto_cam_size(im_poses): + return 0.1 * get_med_dist_between_poses(im_poses) + + +class SceneViz: + def __init__(self): + self.scene = trimesh.Scene() + + def add_rgbd(self, image, depth, intrinsics=None, cam2world=None, zfar=np.inf, mask=None): + image = img_to_arr(image) + + # make up some intrinsics + if intrinsics is None: + H, W, THREE = image.shape + focal = max(H, W) + intrinsics = np.float32([[focal, 0, W/2], [0, focal, H/2], [0, 0, 1]]) + + # compute 3d points + pts3d = depthmap_to_pts3d(depth, intrinsics, cam2world=cam2world) + + return self.add_pointcloud(pts3d, image, mask=(depth 150) + mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180) + mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220) + + # Morphological operations + kernel = np.ones((5, 5), np.uint8) + mask2 = ndimage.binary_opening(mask, structure=kernel) + + # keep only largest CC + _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8) + cc_sizes = stats[1:, cv2.CC_STAT_AREA] + order = cc_sizes.argsort()[::-1] # bigger first + i = 0 + selection = [] + while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2: + selection.append(1 + order[i]) + i += 1 + mask3 = np.in1d(labels, selection).reshape(labels.shape) + + # Apply mask + return torch.from_numpy(mask3) diff --git a/monst3r/install_render.sh b/monst3r/install_render.sh new file mode 100644 index 0000000..669d799 --- /dev/null +++ b/monst3r/install_render.sh @@ -0,0 +1,10 @@ +# pip install pyrender pyglet +# apt-get update +# apt-get install -y libegl1-mesa-dev libgles2-mesa-dev +# apt-get install libglu1-mesa-dev freeglut3-dev + +cd /lpai/volumes/ad-vla-vol-ga/chenantong/codes/monst3r + +apt-get update && apt-get install tmux -y + +pip install -r requirements.txt \ No newline at end of file diff --git a/monst3r/launch.py b/monst3r/launch.py new file mode 100644 index 0000000..65bd135 --- /dev/null +++ b/monst3r/launch.py @@ -0,0 +1,38 @@ +# -------------------------------------------------------- +# training executable for DUSt3R +# -------------------------------------------------------- +from dust3r.training import get_args_parser, train, load_model +from dust3r.pose_eval import eval_pose_estimation +from dust3r.depth_eval import eval_mono_depth_estimation +import croco.utils.misc as misc # noqa +import torch +import torch.backends.cudnn as cudnn +import numpy as np +import os + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + if args.mode.startswith('eval'): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + world_size = misc.get_world_size() + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # fix the seed + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + cudnn.benchmark = args.cudnn_benchmark + model, _ = load_model(args, device) + os.makedirs(args.output_dir, exist_ok=True) + + if args.mode == 'eval_pose': + ate_mean, rpe_trans_mean, rpe_rot_mean, outfile_list, bug = eval_pose_estimation(args, model, device, save_dir=args.output_dir) + print(f'ATE mean: {ate_mean}, RPE trans mean: {rpe_trans_mean}, RPE rot mean: {rpe_rot_mean}') + if args.mode == 'eval_depth': + eval_mono_depth_estimation(args, model, device) + + exit(0) + train(args) diff --git a/monst3r/prepare_data_edit.py b/monst3r/prepare_data_edit.py new file mode 100755 index 0000000..1fcb850 --- /dev/null +++ b/monst3r/prepare_data_edit.py @@ -0,0 +1,674 @@ +import sys +import time +sys.path.append('./extern/dust3r') +from dust3r.inference import inference, load_model +from dust3r.utils.image import load_images +from dust3r.image_pairs import make_pairs +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode +from dust3r.utils.device import to_numpy +import trimesh +import torch +import numpy as np +import torchvision +import os +import copy +import cv2 +import glob +from PIL import Image +import pytorch3d +from pytorch3d.structures import Pointclouds +from torchvision.utils import save_image +import torch.nn.functional as F +import torchvision.transforms as transforms +from PIL import Image +from utils.pvd_utils import * +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything +from utils.diffusion_utils import instantiate_from_config,load_model_checkpoint,image_guided_synthesis +from pathlib import Path +from torchvision.utils import save_image +import json +import random + +from configs.infer_config import get_parser +from utils.pvd_utils import * +from datetime import datetime +from tqdm import tqdm + +global json_path, data_root + +json_path = None +data_root = None + +def process_instance_sequence(frames, corners_dict, img_size=(1600, 900), drop_crop=False, device='cuda'): + """ + Process each instance: crop from first frame and replant in subsequent frames + + Args: + frames: tensor of shape (25, H, W, 3) or (25, 576, 1024, 3) + corners_dict: dict of {instance_id: corners array (25, 8, 2)} in 1600x900 scale + img_size: tuple of original corner coordinates scale + Returns: + processed_frames: frames with replanted instances + """ + # Convert frames to GPU tensor if needed + if not torch.is_tensor(frames): + frames = torch.from_numpy(frames).to(device) + elif frames.device.type != 'cuda': + frames = frames.to(device) + + processed_frames = frames.clone() + num_frames, H, W = frames.shape[:3] + + # Calculate scaling factors + scale_x = W / img_size[0] + scale_y = H / img_size[1] + + for instance_id, corners in corners_dict.items(): + # Get first frame corners and move to GPU + first_frame_corners = corners[0][0].clone().T # (8, 2) + first_frame_corners[:, 0] *= scale_x + first_frame_corners[:, 1] *= scale_y + + # Convert 8 corners to 4-point bbox + first_bbox = corners_to_bbox(first_frame_corners, device=device) + + # Crop the instance from first frame + instance_crop = crop_bbox(frames[0], first_bbox, device=device) + + if drop_crop: + # Generate drop percentages on GPU + drop_percentages = torch.linspace(0, 0.5, num_frames-1, device=device) + else: + drop_percentages = torch.zeros(num_frames-1, device=device) + + # Replant in each frame + for frame_idx in range(1, num_frames): + if corners[frame_idx] is None: + continue + current_corners = corners[frame_idx][0].clone().T + current_corners[:, 0] *= scale_x + current_corners[:, 1] *= scale_y + current_bbox = corners_to_bbox(current_corners, device=device) + + # Paste the cropped instance + processed_frames[frame_idx] = paste_cropped_area( + processed_frames[frame_idx], + instance_crop, + current_bbox, + drop_percentage=drop_percentages[frame_idx-1].item(), + device=device + ) + + return processed_frames + +def corners_to_bbox(corners, device='cuda'): + """Convert 8x2 corners to 4x2 bbox""" + # Convert numpy array to tensor if needed and move to GPU + if not torch.is_tensor(corners): + corners = torch.from_numpy(corners).to(device) + elif corners.device.type != 'cuda': + corners = corners.to(device) + + x_min = torch.min(corners[:, 0]) + x_max = torch.max(corners[:, 0]) + y_min = torch.min(corners[:, 1]) + y_max = torch.max(corners[:, 1]) + + return torch.tensor([ + [x_min, y_min], # top-left + [x_max, y_min], # top-right + [x_max, y_max], # bottom-right + [x_min, y_max] # bottom-left + ], device=device) + +def crop_bbox(image, bbox, device='cuda'): + """Crop image region defined by bbox""" + if not torch.is_tensor(image): + image = torch.from_numpy(image).to(device) + elif image.device.type != 'cuda': + image = image.to(device) + + if not torch.is_tensor(bbox): + bbox = torch.from_numpy(bbox).to(device) + elif bbox.device.type != 'cuda': + bbox = bbox.to(device) + + x_min = int(torch.min(bbox[:, 0]).item()) + x_max = int(torch.max(bbox[:, 0]).item()) + y_min = int(torch.min(bbox[:, 1]).item()) + y_max = int(torch.max(bbox[:, 1]).item()) + + # Ensure coordinates are within image bounds + x_min = max(0, x_min) + y_min = max(0, y_min) + x_max = min(image.shape[1], x_max) + y_max = min(image.shape[0], y_max) + + return image[y_min:y_max, x_min:x_max].clone() + +def paste_cropped_area(target_image, cropped_area, target_bbox, drop_percentage=0., device='cuda'): + """Resize and paste cropped area with smooth blending""" + # Move all inputs to GPU + if not torch.is_tensor(target_image): + target_image = torch.from_numpy(target_image).to(device) + elif target_image.device.type != 'cuda': + target_image = target_image.to(device) + + if not torch.is_tensor(cropped_area): + cropped_area = torch.from_numpy(cropped_area).to(device) + elif cropped_area.device.type != 'cuda': + cropped_area = cropped_area.to(device) + + if not torch.is_tensor(target_bbox): + target_bbox = torch.from_numpy(target_bbox).to(device) + elif target_bbox.device.type != 'cuda': + target_bbox = target_bbox.to(device) + + # Get target dimensions + original_x_min = target_x_min = int(torch.min(target_bbox[:, 0]).item()) + original_x_max = target_x_max = int(torch.max(target_bbox[:, 0]).item()) + original_y_min = target_y_min = int(torch.min(target_bbox[:, 1]).item()) + original_y_max = target_y_max = int(torch.max(target_bbox[:, 1]).item()) + original_width = original_x_max - original_x_min + original_height = original_y_max - original_y_min + + # Ensure coordinates are within bounds + target_x_min = max(0, target_x_min) + target_y_min = max(0, target_y_min) + target_x_max = min(target_image.shape[1], target_x_max) + target_y_max = min(target_image.shape[0], target_y_max) + + # Calculate target size + target_width = target_x_max - target_x_min + target_height = target_y_max - target_y_min + + if target_width <= 0 or target_height <= 0: + return target_image + + # Resize cropped area using torch interpolate + if (target_width, target_height) != (original_width, original_height): + resized_crop = F.interpolate( + cropped_area.permute(2, 0, 1).unsqueeze(0), + size=(original_height, original_width), + mode='bilinear', + align_corners=False + ).squeeze(0).permute(1, 2, 0) + + # Determine valid region + crop_x_start = max(0, target_x_min - original_x_min) + crop_x_end = crop_x_start + target_width + crop_y_start = max(0, target_y_min - original_y_min) + crop_y_end = crop_y_start + target_height + + resized_crop = resized_crop[crop_y_start:crop_y_end, crop_x_start:crop_x_end] + else: + resized_crop = F.interpolate( + cropped_area.permute(2, 0, 1).unsqueeze(0), + size=(target_height, target_width), + mode='bilinear', + align_corners=False + ).squeeze(0).permute(1, 2, 0) + + # Random pixel dropping using torch + if drop_percentage > 0: + mask_drop = torch.rand(resized_crop.shape[:2], device=device) > drop_percentage + if resized_crop.ndim == 3: + mask_drop = mask_drop.unsqueeze(-1).expand(-1, -1, resized_crop.shape[2]) + resized_crop = resized_crop * mask_drop.float() + + # Create result image and mask + result_image = target_image.clone() + mask = torch.ones(resized_crop.shape[:2], device=device) + + # Apply Gaussian blur to mask + mask = mask.unsqueeze(0).unsqueeze(0) + gaussian_kernel = torch.tensor([ + [1, 4, 6, 4, 1], + [4, 16, 24, 16, 4], + [6, 24, 36, 24, 6], + [4, 16, 24, 16, 4], + [1, 4, 6, 4, 1] + ], device=device).float() / 256.0 + gaussian_kernel = gaussian_kernel.view(1, 1, 5, 5) + mask = F.conv2d(mask, gaussian_kernel, padding=2) + mask = mask.squeeze() + + try: + # Blend images + roi = result_image[target_y_min:target_y_max, target_x_min:target_x_max] + mask = mask.unsqueeze(-1).expand(-1, -1, 3) + blended = (resized_crop * mask + roi * (1 - mask)) + + result_image[target_y_min:target_y_max, target_x_min:target_x_max] = blended + except Exception as e: + print(f"Error in blending paste_cropped_area: {e}") + print(f"target_y_min: {target_y_min}, target_y_max: {target_y_max}, target_x_min: {target_x_min}, target_x_max: {target_x_max}") + + return result_image + +def resize_and_reshape_image(image_tensor): + """ + Resize and reshape image tensor from (900, 1600, 3) to (1, 576, 1024, 3) + + Args: + image_tensor: tensor of shape (H, W, 3) + Returns: + resized_tensor: tensor of shape (1, 576, 1024, 3) + """ + # 1. Add batch and permute to (1, 3, H, W) for F.interpolate + x = image_tensor.permute(2, 0, 1).unsqueeze(0) # (1, 3, 900, 1600) + + # 2. Resize the image + resized = F.interpolate( + x, + size=(576, 1024), + mode='bilinear', + align_corners=False + ) + + # 3. Permute back to (1, H, W, 3) + reshaped = resized.permute(0, 2, 3, 1) + + return reshaped + +class ViewCrafter_DataEngine: + def __init__(self, opts, gradio = False): + self.opts = opts + self.device = opts.device + self.setup_dust3r() + + def run_dust3r(self, input_images,clean_pc = False): + pairs = make_pairs(input_images, scene_graph='complete', prefilter=None, symmetrize=True) + output = inference(pairs, self.dust3r, self.device, batch_size=self.opts.batch_size) + + mode = GlobalAlignerMode.PointCloudOptimizer #if len(self.images) > 2 else GlobalAlignerMode.PairViewer + scene = global_aligner(output, device=self.device, mode=mode) + if mode == GlobalAlignerMode.PointCloudOptimizer: + loss = scene.compute_global_alignment(init='mst', niter=self.opts.niter, schedule=self.opts.schedule, lr=self.opts.lr) + + if clean_pc: + self.scene = scene.clean_pointcloud() + else: + self.scene = scene + + pc = self.scene.get_pts3d(clip_thred=self.opts.dpt_trd)[0].detach() + + def render_pcd(self,pts3d,imgs,masks,views,renderer,device,nbv=False): + + # imgs = to_numpy(imgs) + # pts3d = to_numpy(pts3d) + + if masks is None: + # pts = torch.from_numpy(np.concatenate([p for p in pts3d])).view(-1, 3).to(device) + # col = torch.from_numpy(np.concatenate([p for p in imgs])).view(-1, 3).to(device) + if isinstance(pts3d, torch.Tensor): + pts = pts3d.to(device) + else: + pts = torch.from_numpy(pts3d).to(device) + if isinstance(imgs, torch.Tensor): + col = imgs.to(device) + else: + col = torch.from_numpy(imgs).to(device) + else: + # masks = to_numpy(masks) + pts = torch.from_numpy(np.concatenate([p[m] for p, m in zip(pts3d, masks)])).to(device) + col = torch.from_numpy(np.concatenate([p[m] for p, m in zip(imgs, masks)])).to(device) + + point_cloud = Pointclouds(points=[pts], features=[col]).extend(views).to(self.device) + images = renderer(point_cloud) + + if nbv: + color_mask = torch.ones(col.shape).to(device) + point_cloud_mask = Pointclouds(points=[pts],features=[color_mask]).extend(views) + view_masks = renderer(point_cloud_mask) + else: + view_masks = None + + return images, view_masks + + def run_render(self, pcd, imgs,masks, H, W, camera_traj,num_views,nbv=False): + render_setup = setup_renderer(camera_traj, image_size=(H,W)) + renderer = render_setup['renderer'].to(self.device) + render_results, viewmask = self.render_pcd(pcd, imgs, masks, num_views,renderer,self.device,nbv=False) + return render_results, viewmask + + def generate_traj_specified_single(self, c2ws_anchor,H,W,fs,c,theta=0., phi=0.,d_r=0.,d_x=0.,d_y=0.,frame=1,device='cuda'): + # Initialize a camera. + """ + The camera coordinate sysmte in COLMAP is right-down-forward + Pytorch3D is left-up-forward + """ + + c2w_new = sphere2pose(c2ws_anchor, np.float32(theta), np.float32(phi), np.float32(d_r), device, np.float32(d_x),np.float32(d_y)) + c2ws = c2w_new + num_views = c2ws.shape[0] + + R, T = c2ws[:,:3, :3], c2ws[:,:3, 3:] + ## 将dust3r坐标系转成pytorch3d坐标系 + R = torch.stack([-R[:,:, 0], -R[:,:, 1], R[:,:, 2]], 2) # from RDF to LUF for Rotation + new_c2w = torch.cat([R, T], 2) + w2c = torch.linalg.inv(torch.cat((new_c2w, torch.Tensor([[[0,0,0,1]]]).to(device).repeat(new_c2w.shape[0],1,1)),1)) + R_new, T_new = w2c[:,:3, :3].permute(0,2,1), w2c[:,:3, 3] # convert R to row-major matrix + image_size = ((H, W),) # (h, w) + cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False, image_size=image_size, R=R_new, T=T_new, device=device) + return cameras,num_views + + def nvs_multi_pose_fixed_view_change_objects(self, images, img_ori, target_frames, device, save_dir=None, instance_corners=None, phi=0., theta=0., d_r=0., d_x=0., d_y=0.): + """ + Render multiple camera poses from a fixed viewing transformation + + Args: + scene: Scene object containing camera parameters and point clouds + images: Dictionary of input images + img_ori: Original input images + target_frames: List of target frame indices + sampled_indices: List of sampled frame indices + opts: Options/configuration object + device: Torch device to use + c2ws: Camera poses in world coordinates [N, 4, 4] + intrinsics: Camera intrinsics matrix [3, 3] + save_dir: Directory to save outputs + phi: Fixed view transformation parameter for rotation around y-axis + theta: Fixed view transformation parameter for rotation around x-axis + d_r: Fixed view transformation parameter for radius change + d_x: Fixed view transformation parameter for x translation + d_y: Fixed view transformation parameter for y translation + gradio: Boolean flag for gradio interface + """ + # Get all camera poses and parameters + c2ws = self.scene.get_im_poses().detach() + # Use only first frame's camera intrinsics + principal_points = self.scene.get_principal_points()[0:1].detach() + focals = self.scene.get_focals()[0:1].detach() + + # Get image dimensions + shape = images[0]['true_shape'] + H, W = int(shape[0][0]), int(shape[0][1]) + + print(f'true shape H: {H}, W: {W}') + + # Get point clouds and depth + pcd = [i.detach() for i in self.scene.get_pts3d(clip_thred=self.opts.dpt_trd)] + + + depth = [i.detach() for i in self.scene.get_depthmaps()] + depth_avg = depth[0][H//2,W//2] # Use first frame's depth + radius = depth_avg*self.opts.center_scale + + # Transform coordinates to object-centric + c2ws, pcd = world_point_to_obj( + poses=c2ws, + points=torch.stack(pcd), + k=0, # Use first frame as reference + r=radius, + elevation=self.opts.elevation, + device=device + ) + + imgs = np.array(self.scene.imgs) + + print(f'imgs shape: {imgs.shape}') + masks = None + + # Apply the fixed view transformation to all camera poses + camera_poses = [] + for i in range(len(c2ws)): + # Generate single transformed pose + single_pose, _ = self.generate_traj_specified_single( + c2ws[i:i+1], # Take one pose at a time + H, W, + focals, # Use first frame's focal length for all + principal_points, # Use first frame's principal point for all + theta, phi, d_r, d_x, d_y, + 1, # Only generate one frame per pose + device + ) + camera_poses.append(single_pose) + + # Combine all transformed cameras + camera_traj = camera_poses[0] + for cam in camera_poses[1:]: + camera_traj.R = torch.cat([camera_traj.R, cam.R]) + camera_traj.T = torch.cat([camera_traj.T, cam.T]) + + # Concatenate all point clouds and images + # all_pcd = torch.cat([p.reshape(-1, 3) for p in pcd], dim=0) + # all_imgs = np.concatenate([img.reshape(-1, 3) for img in imgs], axis=0) + + def black_out_instance_cv(point_cloud, corners_dict, img_size=(1600, 900)): + """ + Black out instances in point cloud using scaled corners + + Args: + point_cloud: tensor of shape (H, W, 3) + corners_dict: dict of {instance_id: corners array (25, 8, 2)} in 1600x900 scale + img_size: tuple of (width, height) of original corner coordinates + Returns: + modified_pc: point cloud with blacked out instances + """ + # Ensure point cloud is on the same device + device = point_cloud.device + modified_pc = point_cloud.clone() # Create a copy of the point cloud + H, W = modified_pc.shape[:2] + + # Calculate scaling factors + scale_x = W / img_size[0] + scale_y = H / img_size[1] + + for instance_id, corners in corners_dict.items(): + # Get first frame corners + first_frame_corners = corners[0][0].clone().T # (8, 2) + + first_frame_corners[:, 0] *= scale_x + first_frame_corners[:, 1] *= scale_y + + # Get bbox coordinates and scale them using torch + x_min = int(torch.min(first_frame_corners[:, 0]).item()) + x_max = int(torch.max(first_frame_corners[:, 0]).item()) + y_min = int(torch.min(first_frame_corners[:, 1]).item()) + y_max = int(torch.max(first_frame_corners[:, 1]).item()) + + # Ensure coordinates are within bounds + x_min = max(0, x_min) + x_max = min(W, x_max) + y_min = max(0, y_min) + y_max = min(H, y_max) + + # Black out the bbox region + modified_pc[y_min:y_max, x_min:x_max] = 0 + + return modified_pc + + os.makedirs(save_dir, exist_ok=True) + #### + if len(instance_corners) > 0: + pcd[0] = black_out_instance_cv(pcd[0], instance_corners) + # Render from all transformed viewpoints + render_results, viewmask = self.run_render( + pcd[0].reshape(-1, 3),#all_pcd, + imgs[0].reshape(-1, 3),#all_imgs, + masks, + H, W, + camera_traj, + len(c2ws), + device + ) + + # Resize renders to target resolution + render_results = F.interpolate( + render_results.permute(0,3,1,2), + size=(576, 1024), + mode='bilinear', + align_corners=False + ).permute(0,2,3,1) + + # Save outputs + # save_video(render_results, os.path.join(save_dir, 'original_render.mp4')) + # save_frames(render_results, save_dir, save_names=[f'original_base{target_frames[0]}_idx{idx}' for idx in sampled_indices], folder=None) + # save_pointcloud_with_normals( + # imgs[0].reshape(-1, 3),#all_imgs, + # pcd[0].reshape(-1, 3),#all_pcd, + # msk=None, + # save_path=os.path.join(save_dir,'pcd_all.ply'), + # mask_pc=False, + # reduce_pc=False + # ) + # print(f'saved to {save_dir}') + + + return render_results + + def setup_dust3r(self): + self.dust3r = load_model(self.opts.model_path, self.device) + self.scene = None + + def process_single_scene(self, scene_data, start_idx=None, interval=None): + """ + Process a single scene - extracted from process_dataset for parallel execution + """ + start_time = time.time() + idx, scene, data_root = scene_data + + target_frames = [0] + self.target_frames = target_frames + + frame_files = scene['frames'] + frame_paths = [os.path.join(data_root, p) for p in frame_files] + instance_corners = scene['instance_corners'] + + save_dir = os.path.join(self.opts.save_dir, f'sample_{idx:06d}') + if os.path.exists(save_dir) and len(os.listdir(save_dir)) >= 25: + print(f'skip {save_dir}') + return None, None + + os.makedirs(save_dir, exist_ok=True) + + images = load_images(frame_paths, size=512, force_1024=True) + img_ori = [(images[i]['img_ori'].squeeze(0).permute(1,2,0)+1.)/2. for i in range(len(images))] + + self.images = images + self.img_ori = img_ori + self.run_dust3r(input_images=images, clean_pc=True) + # self.nvs_multi_pose_fixed_view(c2ws=None, intrinsics=camera_intrinsics, save_dir=save_dir) # this is for dust3r pose + rendered_results = self.nvs_multi_pose_fixed_view_change_objects(images, img_ori, target_frames, self.device, save_dir, instance_corners=instance_corners) # this is for nuscenes pose + + self.scene = None + torch.cuda.empty_cache() + + firstframe_img_ori = resize_and_reshape_image(img_ori[0]) + rendered_results[0] = firstframe_img_ori + if len(instance_corners) > 0: + print('encountered instance corners') + # return None, None + processed_frames = process_instance_sequence(rendered_results, instance_corners) + else: + processed_frames = rendered_results + + sampled_indices = range(25) + # save_video(processed_frames, os.path.join(save_dir, 'edited_render.mp4')) + save_frames(processed_frames, save_dir, save_names=[f'edited_base{target_frames[0]}_idx{idx}' for idx in sampled_indices], folder=None) + + print(f'data saved in {save_dir}') + # save_video(torch.stack(img_ori), os.path.join(save_dir, 'frames.mp4')) + + # scene.update({'render_frames': [f'sample_{idx}/edited_base{target_frames[0]}_idx{idx}' for idx in sampled_indices]}) + + # # Convert NumPy arrays to lists and handle None sublists + # for key in scene['instance_corners']: + # scene['instance_corners'][key] = [ + # [array.tolist() for array in sublist] if sublist is not None else None + # for sublist in scene['instance_corners'][key] + # ] + + used_time = time.time() - start_time + + + return scene, used_time + + + def process_dataset(self, json_path, data_root, output_dir): + """ + Process entire dataset sequentially and save results + """ + + frame_cam_corners_samples = torch.load('../new_frame_cam_corners_samples.pth') + frame_cam_samples_indices = torch.load('../frame_cam_samples_indices.pth') + + rest_frame_cam_corners_samples = [] + rest_frame_cam_samples_indices = [] + for i in range(len(frame_cam_corners_samples)): + idx = frame_cam_samples_indices[i] + save_dir = os.path.join(self.opts.save_dir, f'sample_{idx:06d}') + if os.path.exists(save_dir) and len(os.listdir(save_dir)) >= 25: + print(f'skip {save_dir}') + continue + rest_frame_cam_corners_samples.append(frame_cam_corners_samples[i]) + rest_frame_cam_samples_indices.append(idx) + + # Define the function to split and extract nth split + def split_and_get_nth(data, n, splits=16): + if n < 0 or n >= splits: + raise ValueError(f"Invalid split index {n}. Must be between 0 and {splits - 1}.") + + # Calculate the size of each split + split_size = len(data) // splits + remainder = len(data) % splits # For uneven splitting + + # Create the splits + splits_indices = [] + start_idx = 0 + for i in range(splits): + extra = 1 if i < remainder else 0 # Add extra for first 'remainder' splits + splits_indices.append((start_idx, start_idx + split_size + extra)) + start_idx += split_size + extra + + start, end = splits_indices[n] + return data[start:end] + + split_corners = split_and_get_nth(rest_frame_cam_corners_samples, self.opts.section, self.opts.total_sections) + split_indices = split_and_get_nth(rest_frame_cam_samples_indices, self.opts.section, self.opts.total_sections) + + # Prepare data for sequential processing + scene_data = [(idx, scene, data_root) for idx, scene in zip(split_indices, split_corners)] + + training_samples = [] + + times = list() + for data in tqdm(scene_data, desc='Processing scenes'): # Limited to 4 scenes as in original + + try: + result, used_time = self.process_single_scene(data) + if result is None: + continue + training_samples.append(result) + times.append(used_time) + avg_time = sum(times) / len(times) + total_time = avg_time * (len(scene_data) - len(times)) + print(f'Split:{self.opts.section}, time used: {times[-1]:.2f}s, estimated left time: {total_time/3600:.2f}h') + except Exception as e: + print(f'Error: {e}') + continue + + # Save metadata + # with open(os.path.join(os.path.dirname(json_path), 'nuScenes_frontview_base0_fullframes_rendered.json'), 'a') as f: + # json.dump(training_samples, f) + +# Example usage: +# render_results = render_dust3r_scene(scene) +# save_video(render_results, 'rendered_views.mp4') + +if __name__=="__main__": + parser = get_parser() # infer config.py + opts = parser.parse_args() + if opts.exp_name == None: + prefix = datetime.now().strftime("%Y%m%d_%H%M") + opts.exp_name = f'{prefix}_{os.path.splitext(os.path.basename(opts.image_dir))[0]}' + opts.save_dir = os.path.join(opts.out_dir,opts.exp_name) + os.makedirs(opts.save_dir,exist_ok=True) + pvd = ViewCrafter_DataEngine(opts) + + json_path = '/lpai/volumes/ad-vla-vol-ga/chenantong/datasets/NuScenes/annotation/nuScenes_frontview_caminex_fullframes.json' + data_root = "/lpai/dataset/nuscenes-imgs/0-1-0/nuscenes" + + pvd.process_dataset(json_path, data_root, opts.save_dir) \ No newline at end of file diff --git a/monst3r/prepare_data_monst3r.py b/monst3r/prepare_data_monst3r.py new file mode 100644 index 0000000..443cd6a --- /dev/null +++ b/monst3r/prepare_data_monst3r.py @@ -0,0 +1,414 @@ +from asyncio import tasks +import sys +import time +import os +from altair import Theta +from matplotlib.pyplot import sca +import torch +import numpy as np +import cv2 +from PIL import Image +import json +from tqdm import tqdm +from datetime import datetime +import argparse +import functools +import pytorch3d +from pytorch3d.structures import Pointclouds +from pytorch3d.renderer import PerspectiveCameras +import torch.nn.functional as F + +# MonST3R imports +from dust3r.inference import inference +from dust3r.model import AsymmetricCroCo3DStereo +from dust3r.image_pairs import make_pairs +from dust3r.utils.image import load_images +from dust3r.utils.device import to_numpy +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode +from dust3r.utils.viz_demo import convert_scene_output_to_glb +import copy + +from pvd_utils import * + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_path', type=str, default='checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth') + parser.add_argument('--model_name', type=str, default='Junyi42/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt') + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--image_size', type=int, default=512) + parser.add_argument('--out_dir', type=str, default='outputs') + parser.add_argument('--exp_name', type=str, default=None) + parser.add_argument('--section', type=int, default=0) + parser.add_argument('--total_sections', type=int, default=16) + parser.add_argument('--niter', type=int, default=300) + parser.add_argument('--schedule', type=str, default='linear') + parser.add_argument('--center_scale', type=float, default=1.0) + parser.add_argument('--dpt_trd', type=float, default=1.) + parser.add_argument('--bg_trd', type=float, default=0.) + parser.add_argument('--theta', type=float, default=0.) + parser.add_argument('--phi', type=float, default=0.) + parser.add_argument('--d_r', type=float, default=0.) + parser.add_argument('--d_x', type=float, default=0.) + parser.add_argument('--d_y', type=float, default=0.) + parser.add_argument('--elevation', type=float, default=5.) + parser.add_argument('--min_conf_thr', type=float, default=1.1) + parser.add_argument('--scenegraph_type', type=str, default='swinstride') + parser.add_argument('--swinsize', type=int, default=5) + parser.add_argument('--refid', type=int, default=0) + return parser + +# Add helper function for camera setup +def setup_renderer(cameras, image_size): + """Setup PyTorch3D renderer""" + from pytorch3d.renderer import ( + PointsRasterizer, + PointsRenderer, + AlphaCompositor, + PointsRasterizationSettings, + ) + + raster_settings = PointsRasterizationSettings( + image_size=image_size, + radius=0.003, + points_per_pixel=10 + ) + + rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) + renderer = PointsRenderer( + rasterizer=rasterizer, + compositor=AlphaCompositor() + ) + + return {"renderer": renderer, "cameras": cameras} + +def sphere2pose(c2ws_input, theta, phi, r, device,x=None,y=None): + # c2ws = copy.deepcopy(c2ws_input) # this produces a bug + c2ws = c2ws_input.detach().clone().to(device) + #先沿着世界坐标系z轴方向平移再旋转 + c2ws[:,2,3] += r + if x is not None: + c2ws[:,1,3] += y + if y is not None: + c2ws[:,0,3] += x + + theta = torch.deg2rad(torch.tensor(theta)).to(device) + sin_value_x = torch.sin(theta) + cos_value_x = torch.cos(theta) + rot_mat_x = torch.tensor([[1, 0, 0, 0], + [0, cos_value_x, -sin_value_x, 0], + [0, sin_value_x, cos_value_x, 0], + [0, 0, 0, 1]]).unsqueeze(0).repeat(c2ws.shape[0],1,1).to(device) + + phi = torch.deg2rad(torch.tensor(phi)).to(device) + sin_value_y = torch.sin(phi) + cos_value_y = torch.cos(phi) + rot_mat_y = torch.tensor([[cos_value_y, 0, sin_value_y, 0], + [0, 1, 0, 0], + [-sin_value_y, 0, cos_value_y, 0], + [0, 0, 0, 1]]).unsqueeze(0).repeat(c2ws.shape[0],1,1).to(device) + + c2ws = torch.matmul(rot_mat_x,c2ws) + c2ws = torch.matmul(rot_mat_y,c2ws) + + return c2ws + +class MonST3R_DataEngine: + def __init__(self, opts): + self.opts = opts + self.device = opts.device + self.setup_monst3r() + + def setup_monst3r(self): + """Initialize MonST3R model""" + if os.path.exists(self.opts.model_path): + weights_path = self.opts.model_path + else: + weights_path = self.opts.model_name + + self.model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(self.device) + self.model.eval() + + def render_pcd(self, pts3d, imgs, masks, views, renderer, device, nbv=False): + """Render point cloud from different viewpoints""" + imgs = to_numpy(imgs) + pts3d = to_numpy(pts3d) + + if masks is None: + pts = torch.from_numpy(np.concatenate([p for p in pts3d])).view(-1, 3).to(device) + col = torch.from_numpy(np.concatenate([p for p in imgs])).view(-1, 3).to(device) + else: + # masks = to_numpy(masks) + pts = torch.from_numpy(np.concatenate([p[m] for p, m in zip(pts3d, masks)])).to(device) + col = torch.from_numpy(np.concatenate([p[m] for p, m in zip(imgs, masks)])).to(device) + + point_cloud = Pointclouds(points=[pts], features=[col]).extend(views).to(device) + images = renderer(point_cloud) + + if nbv: + color_mask = torch.ones(col.shape).to(device) + point_cloud_mask = Pointclouds(points=[pts], features=[color_mask]).extend(views) + view_masks = renderer(point_cloud_mask) + else: + view_masks = None + + return images, view_masks + + def run_render(self, pcd, imgs, masks, H, W, camera_traj, num_views, nbv=False): + """Run rendering pipeline""" + render_setup = setup_renderer(camera_traj, image_size=(H, W)) + renderer = render_setup['renderer'].to(self.device) + render_results, viewmask = self.render_pcd(pcd, imgs, masks, num_views, renderer, self.device, nbv=nbv) + return render_results, viewmask + + def generate_traj_specified_single(self, c2ws_anchor, H, W, fs, c, theta=0., phi=0., d_r=0., d_x=0., d_y=0., frame=1, device='cuda'): + """Generate camera trajectory for rendering""" + c2w_new = sphere2pose(c2ws_anchor, np.float32(theta), np.float32(phi), np.float32(d_r), device, np.float32(d_x), np.float32(d_y)) + c2ws = c2w_new + num_views = c2ws.shape[0] + R, T = c2ws[:,:3, :3], c2ws[:,:3, 3:] + # Convert from dust3r to pytorch3d coordinate system + R = torch.stack([-R[:,:, 0], -R[:,:, 1], R[:,:, 2]], 2) # from RDF to LUF for Rotation + new_c2w = torch.cat([R, T], 2) + w2c = torch.linalg.inv(torch.cat((new_c2w, torch.Tensor([[[0,0,0,1]]]).to(device).repeat(new_c2w.shape[0],1,1)), 1)) + R_new, T_new = w2c[:,:3, :3].permute(0,2,1), w2c[:,:3, 3] + + image_size = ((H, W),) + # R_new: (N, 3, 3) + # T_new: (N, 3) + # fs: (1) + # c: (2) + + cameras = PerspectiveCameras( + focal_length=fs, + principal_point=c, + in_ndc=False, + image_size=image_size, + R=R_new, + T=T_new, + device=device + ) + return cameras, num_views + + def process_single_scene(self, scene_data): + """Process a single scene using MonST3R and render results""" + start_time = time.time() + idx, scene, data_root = scene_data + + frame_files = scene['frames'][:5] # test + frame_paths = [os.path.join(data_root, p) for p in frame_files] + instance_corners = scene['instance_corners'] + + save_dir = os.path.join(self.opts.save_dir, f'sample_{idx:06d}') + if os.path.exists(save_dir) and len(os.listdir(save_dir)) >= 25: + print(f'skip {save_dir}') + return None, None + + os.makedirs(save_dir, exist_ok=True) + + self.model.eval() + # Load and process images using MonST3R pipeline + images = load_images(frame_paths, size=self.opts.image_size) + # img_ori = (images[0]['img_ori'].squeeze(0).permute(1,2,0)+1.)/2. # [576,1024,3] [0,1] + + scenegraph_type = self.opts.scenegraph_type + if len(images) == 1: + images = [images[0], copy.deepcopy(images[0])] + images[1]['idx'] = 1 + if scenegraph_type == "swin" or scenegraph_type == "swinstride" or scenegraph_type == "swin2stride": + scenegraph_type = scenegraph_type + "-" + str(self.opts.swinsize) + "-noncyclic" + elif scenegraph_type == "oneref": + scenegraph_type = scenegraph_type + "-" + str(self.opts.refid) + + # Make pairs for reconstruction + pairs = make_pairs(images, scene_graph=scenegraph_type, prefilter=None, symmetrize=True) + + # Run inference + output = inference(pairs, self.model, self.device, batch_size=self.opts.batch_size) + + # Global alignment + if len(images) > 2: + mode = GlobalAlignerMode.PointCloudOptimizer + scene = global_aligner(output, device=self.device, mode=mode) + else: + mode = GlobalAlignerMode.PairViewer + scene = global_aligner(output, device=self.device, mode=mode) + + # Optimize global alignment + if mode == GlobalAlignerMode.PointCloudOptimizer: + loss = scene.compute_global_alignment( + init='mst', + niter=self.opts.niter, + schedule=self.opts.schedule, + lr=0.01 + ) + + # Clean point cloud + scene = scene.clean_pointcloud() + + # Save reconstruction outputs + save_folder = os.path.join(save_dir, 'reconstruction') + os.makedirs(save_folder, exist_ok=True) + + # Save camera poses and intrinsics + poses = scene.save_tum_poses(f'{save_folder}/pred_traj.txt') + K = scene.save_intrinsics(f'{save_folder}/pred_intrinsics.txt') + + # Save depth maps and point clouds + depth_maps = scene.save_depth_maps(save_folder) + scene.save_dynamic_masks(save_folder) + + c2ws = scene.get_im_poses().detach() + principal_points = scene.get_principal_points().detach() + focals = scene.get_focals().detach() + shape = images[0]['true_shape'] + H, W = int(shape[0][0]), int(shape[0][1]) + pcd = [i.detach() for i in scene.get_pts3d(clip_thred=self.opts.dpt_trd)] # a list of points of size whc + depth = [i.detach() for i in scene.get_depthmaps()] + depth_avg = depth[0][H//2,W//2] #以ref图像中心处的depth(z)为球心旋转 + radius = depth_avg*self.opts.center_scale #缩放调整 + + ## masks for cleaner point cloud + scene.min_conf_thr = float(scene.conf_trf(torch.tensor(self.opts.min_conf_thr))) + masks = scene.get_masks() + depth = scene.get_depthmaps() + bgs_mask = [dpt > self.opts.bg_trd*(torch.max(dpt[40:-40,:])+torch.min(dpt[40:-40,:])) for dpt in depth] + masks_new = [m+mb for m, mb in zip(masks,bgs_mask)] + masks = to_numpy(masks_new) + + ## render, 从c2ws[0]即ref image对应的相机开始 + imgs = np.array(scene.imgs) + + # Transform coordinates to object-centric + c2ws, pcd = world_point_to_obj( + poses=c2ws, + points=torch.stack(pcd), + k=0, # ! + r=radius, + elevation=self.opts.elevation, + device=self.device + ) + + camera_poses = [] + for i in range(len(c2ws)): + # Generate single transformed pose + single_pose, _ = self.generate_traj_specified_single( + c2ws[i:i+1], # Take one pose at a time + H, W, + focals[0:1], # Use first frame's focal length for all + principal_points[0:1], # Use first frame's principal point for all + self.opts.theta, self.opts.phi, self.opts.d_r, self.opts.d_x, self.opts.d_y, + 1, # Only generate one frame per pose + self.device + ) + camera_poses.append(single_pose) + + # Combine all transformed cameras + camera_traj = camera_poses[0] + for cam in camera_poses[1:]: + camera_traj.R = torch.cat([camera_traj.R, cam.R]) + camera_traj.T = torch.cat([camera_traj.T, cam.T]) + + print(len(pcd)) + print(len(imgs)) + print(len(c2ws)) + + import pdb; pdb.set_trace() + + render_results, _ = self.run_render( + [pcd[0]],#all_pcd, + [imgs[0]], #all_imgs, + [masks[0]], #masks + H, W, + camera_traj, + len(c2ws), + self.device + ) + + # Resize renders to target resolution + render_results = F.interpolate( + render_results.permute(0,3,1,2), + size=(576, 1024), + mode='bilinear', + align_corners=False + ).permute(0,2,3,1) + + # render_results[0] = img_ori + + # Save results + save_frames(render_results, save_dir, save_names=[f'render_{idx:06d}_{i:02d}' for i in range(len(render_results))]) + save_video(render_results, os.path.join(save_dir, 'render.mp4')) + + used_time = time.time() - start_time + return scene, used_time + + def process_dataset(self, json_path, data_root): + """Process entire dataset using MonST3R""" + frame_cam_corners_samples = torch.load('../new_frame_cam_corners_samples.pth') + frame_cam_samples_indices = torch.load('../frame_cam_samples_indices.pth') + + # Filter already processed samples + rest_frame_cam_corners_samples = [] + rest_frame_cam_samples_indices = [] + for i in range(len(frame_cam_corners_samples)): + idx = frame_cam_samples_indices[i] + save_dir = os.path.join(self.opts.save_dir, f'sample_{idx:06d}') + # if os.path.exists(save_dir) and len(os.listdir(save_dir)) >= 25: + # print(f'skip {save_dir}') + # continue + rest_frame_cam_corners_samples.append(frame_cam_corners_samples[i]) + rest_frame_cam_samples_indices.append(idx) + + # Get section of data to process + def split_and_get_nth(data, n, splits=16): + if n < 0 or n >= splits: + raise ValueError(f"Invalid split index {n}. Must be between 0 and {splits - 1}.") + split_size = len(data) // splits + remainder = len(data) % splits + splits_indices = [] + start_idx = 0 + for i in range(splits): + extra = 1 if i < remainder else 0 + splits_indices.append((start_idx, start_idx + split_size + extra)) + start_idx += split_size + extra + start, end = splits_indices[n] + return data[start:end] + + split_corners = split_and_get_nth(rest_frame_cam_corners_samples, self.opts.section, self.opts.total_sections) + split_indices = split_and_get_nth(rest_frame_cam_samples_indices, self.opts.section, self.opts.total_sections) + + # Prepare data for processing + scene_data = [(idx, scene, data_root) for idx, scene in zip(split_indices, split_corners)] + + # Process scenes + times = [] + for data in tqdm(scene_data, desc='Processing scenes'): + # try: + result, used_time = self.process_single_scene(data) + if result is not None: + times.append(used_time) + avg_time = sum(times) / len(times) + total_time = avg_time * (len(scene_data) - len(times)) + print(f'Split:{self.opts.section}, time used: {times[-1]:.2f}s, estimated left time: {total_time/3600:.2f}h') + # except Exception as e: + # print(f'Error processing scene: {e}') + # continue + +if __name__ == "__main__": + parser = get_parser() + opts = parser.parse_args() + + if opts.exp_name is None: + prefix = datetime.now().strftime("%Y%m%d_%H%M") + opts.exp_name = f'{prefix}_monst3r' + + opts.save_dir = os.path.join(opts.out_dir, opts.exp_name) + os.makedirs(opts.save_dir, exist_ok=True) + + monst3r = MonST3R_DataEngine(opts) + + json_path = '/lpai/volumes/ad-vla-vol-ga/chenantong/datasets/NuScenes/annotation/nuScenes_frontview_caminex_fullframes.json' + data_root = "/lpai/dataset/nuscenes-imgs/0-1-0/nuscenes" + + monst3r.process_dataset(json_path, data_root) \ No newline at end of file diff --git a/monst3r/pvd_utils.py b/monst3r/pvd_utils.py new file mode 100755 index 0000000..827e5af --- /dev/null +++ b/monst3r/pvd_utils.py @@ -0,0 +1,789 @@ +import trimesh +import torch +import numpy as np +import os +import math +import torchvision +import scipy +from tqdm import tqdm +import cv2 # Assuming OpenCV is used for image saving +from PIL import Image +import pytorch3d +import random +from PIL import ImageGrab +torchvision +from torchvision.utils import save_image +from pytorch3d.renderer import ( + PointsRasterizationSettings, + PointsRenderer, + PointsRasterizer, + AlphaCompositor, + PerspectiveCameras, +) +import imageio +import torch.nn.functional as F +from torchvision.transforms import ToPILImage +import copy +from scipy.interpolate import interp1d +from scipy.interpolate import UnivariateSpline +from scipy.spatial.transform import Rotation as R +from scipy.spatial.transform import Slerp +import sys + +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +from torchvision.transforms import CenterCrop, Compose, Resize + +def save_video(data,images_path,folder=None): + if isinstance(data, np.ndarray): + tensor_data = (torch.from_numpy(data) * 255).to(torch.uint8) + elif isinstance(data, torch.Tensor): + tensor_data = (data.detach().cpu() * 255).to(torch.uint8) + elif isinstance(data, list): + folder = [folder]*len(data) + images = [np.array(Image.open(os.path.join(folder_name,path))) for folder_name,path in zip(folder,data)] + stacked_images = np.stack(images, axis=0) + tensor_data = torch.from_numpy(stacked_images).to(torch.uint8) + torchvision.io.write_video(images_path, tensor_data, fps=8, video_codec='h264', options={'crf': '10'}) + +def save_frames(data, output_dir, save_names=None, folder=None): + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + if isinstance(data, np.ndarray): + tensor_data = (torch.from_numpy(data) * 255).to(torch.uint8) + elif isinstance(data, torch.Tensor): + tensor_data = (data.detach().cpu() * 255).to(torch.uint8) + elif isinstance(data, list): + folder = [folder]*len(data) + images = [np.array(Image.open(os.path.join(folder_name,path))) for folder_name,path in zip(folder,data)] + stacked_images = np.stack(images, axis=0) + tensor_data = torch.from_numpy(stacked_images).to(torch.uint8) + + # Save each frame as a separate image + for i in range(tensor_data.shape[0]): + frame = tensor_data[i] + frame_path = os.path.join(output_dir, f'{save_names[i]}.png') + Image.fromarray(frame.numpy()).save(frame_path) + +def get_input_dict(img_tensor,idx,dtype = torch.float32): + + return {'img':F.interpolate(img_tensor.to(dtype), size=(288, 512), mode='bilinear', align_corners=False), 'true_shape': np.array([[288, 512]], dtype=np.int32), 'idx': idx, 'instance': str(idx), 'img_ori':img_tensor.to(dtype)} + # return {'img':F.interpolate(img_tensor.to(dtype), size=(288, 512), mode='bilinear', align_corners=False), 'true_shape': np.array([[288, 512]], dtype=np.int32), 'idx': idx, 'instance': str(idx), 'img_ori':ToPILImage()((img_tensor.squeeze(0)+ 1) / 2)} + + +def rotate_theta(c2ws_input, theta, phi, r, device): + # theta: 图像的倾角,新的y’轴(位于yoz平面)与y轴的夹角 + #让相机在以[0,0,depth_avg]为球心的球面上运动,可以先让其在[0,0,0]为球心的球面运动,方便计算旋转矩阵,之后在平移 + c2ws = copy.deepcopy(c2ws_input) + c2ws[:,2, 3] = c2ws[:,2, 3] + r #将相机坐标系沿着世界坐标系-z方向平移r + # 计算旋转向量 + theta = torch.deg2rad(torch.tensor(theta)).to(device) + phi = torch.deg2rad(torch.tensor(phi)).to(device) + v = torch.tensor([0, torch.cos(theta), torch.sin(theta)]) + # 计算反对称矩阵 + v_x = torch.zeros(3, 3).to(device) + v_x[0, 1] = -v[2] + v_x[0, 2] = v[1] + v_x[1, 0] = v[2] + v_x[1, 2] = -v[0] + v_x[2, 0] = -v[1] + v_x[2, 1] = v[0] + + # 计算反对称矩阵的平方 + v_x_square = torch.matmul(v_x, v_x) + + # 计算旋转矩阵 + R = torch.eye(3).to(device) + torch.sin(phi) * v_x + (1 - torch.cos(phi)) * v_x_square + + # 转换为齐次表示 + R_h = torch.eye(4) + R_h[:3, :3] = R + Rot_mat = R_h.to(device) + + c2ws = torch.matmul(Rot_mat, c2ws) + c2ws[:,2, 3]= c2ws[:,2, 3] - r #最后减去r,相当于绕着z=|r|为中心旋转 + + return c2ws + +def sphere2pose(c2ws_input, theta, phi, r, device,x=None,y=None): + c2ws = copy.deepcopy(c2ws_input) + + #先沿着世界坐标系z轴方向平移再旋转 + c2ws[:,2,3] += r + if x is not None: + c2ws[:,1,3] += y + if y is not None: + c2ws[:,0,3] += x + + theta = torch.deg2rad(torch.tensor(theta)).to(device) + sin_value_x = torch.sin(theta) + cos_value_x = torch.cos(theta) + rot_mat_x = torch.tensor([[1, 0, 0, 0], + [0, cos_value_x, -sin_value_x, 0], + [0, sin_value_x, cos_value_x, 0], + [0, 0, 0, 1]]).unsqueeze(0).repeat(c2ws.shape[0],1,1).to(device) + + phi = torch.deg2rad(torch.tensor(phi)).to(device) + sin_value_y = torch.sin(phi) + cos_value_y = torch.cos(phi) + rot_mat_y = torch.tensor([[cos_value_y, 0, sin_value_y, 0], + [0, 1, 0, 0], + [-sin_value_y, 0, cos_value_y, 0], + [0, 0, 0, 1]]).unsqueeze(0).repeat(c2ws.shape[0],1,1).to(device) + + c2ws = torch.matmul(rot_mat_x,c2ws) + c2ws = torch.matmul(rot_mat_y,c2ws) + + return c2ws + +def generate_candidate_poses(c2ws_anchor,H,W,fs,c,theta, phi,num_candidates,device): + # Initialize a camera. + """ + The camera coordinate sysmte in COLMAP is right-down-forward + Pytorch3D is left-up-forward + """ + if num_candidates == 2: + thetas = np.array([0,-theta]) + phis = np.array([phi,phi]) + elif num_candidates == 3: + thetas = np.array([0,-theta,theta/2.]) #avoid too many downward + phis = np.array([phi,phi,phi]) + else: + raise ValueError("NBV mode only supports 2 or 3 candidates per iteration.") + + c2ws_list = [] + + for th, ph in zip(thetas,phis): + c2w_new = sphere2pose(c2ws_anchor, np.float32(th), np.float32(ph), r=None, device= device) + c2ws_list.append(c2w_new) + c2ws = torch.cat(c2ws_list,dim=0) + num_views = c2ws.shape[0] + + R, T = c2ws[:,:3, :3], c2ws[:,:3, 3:] + ## 将dust3r坐标系转成pytorch3d坐标系 + R = torch.stack([-R[:,:, 0], -R[:,:, 1], R[:,:, 2]], 2) # from RDF to LUF for Rotation + new_c2w = torch.cat([R, T], 2) + w2c = torch.linalg.inv(torch.cat((new_c2w, torch.Tensor([[[0,0,0,1]]]).to(device).repeat(new_c2w.shape[0],1,1)),1)) + R_new, T_new = w2c[:,:3, :3].permute(0,2,1), w2c[:,:3, 3] # convert R to row-major matrix + image_size = ((H, W),) # (h, w) + cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False, image_size=image_size, R=R_new, T=T_new, device=device) + return cameras,thetas,phis + +def interpolate_poses_spline(poses, n_interp, spline_degree=5, + smoothness=.03, rot_weight=.1): + """Creates a smooth spline path between input keyframe camera poses. + + Spline is calculated with poses in format (position, lookat-point, up-point). + + Args: + poses: (n, 3, 4) array of input pose keyframes. + n_interp: returned path will have n_interp * (n - 1) total poses. + spline_degree: polynomial degree of B-spline. + smoothness: parameter for spline smoothing, 0 forces exact interpolation. + rot_weight: relative weighting of rotation/translation in spline solve. + + Returns: + Array of new camera poses with shape (n_interp * (n - 1), 3, 4). + """ + + def poses_to_points(poses, dist): + """Converts from pose matrices to (position, lookat, up) format.""" + pos = poses[:, :3, -1] + lookat = poses[:, :3, -1] - dist * poses[:, :3, 2] + up = poses[:, :3, -1] + dist * poses[:, :3, 1] + return np.stack([pos, lookat, up], 1) + + def points_to_poses(points): + """Converts from (position, lookat, up) format to pose matrices.""" + return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points]) + + def interp(points, n, k, s): + """Runs multidimensional B-spline interpolation on the input points.""" + sh = points.shape + pts = np.reshape(points, (sh[0], -1)) + k = min(k, sh[0] - 1) + tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s) + u = np.linspace(0, 1, n, endpoint=False) + new_points = np.array(scipy.interpolate.splev(u, tck)) + new_points = np.reshape(new_points.T, (n, sh[1], sh[2])) + return new_points + + def viewmatrix(lookdir, up, position): + """Construct lookat view matrix.""" + vec2 = normalize(lookdir) + vec0 = normalize(np.cross(up, vec2)) + vec1 = normalize(np.cross(vec2, vec0)) + m = np.stack([vec0, vec1, vec2, position], axis=1) + return m + + def normalize(x): + """Normalization helper function.""" + return x / np.linalg.norm(x) + + points = poses_to_points(poses, dist=rot_weight) + new_points = interp(points, + n_interp * (points.shape[0] - 1), + k=spline_degree, + s=smoothness) + new_poses = points_to_poses(new_points) + poses_tensor = torch.from_numpy(new_poses) + extra_row = torch.tensor(np.repeat([[0, 0, 0, 1]], n_interp, axis=0), dtype=torch.float32).unsqueeze(1) + poses_final = torch.cat([poses_tensor, extra_row], dim=1) + + return poses_final + +def interp_traj(c2ws: torch.Tensor, n_inserts: int = 25, device='cuda') -> torch.Tensor: + + n_poses = c2ws.shape[0] + interpolated_poses = [] + + for i in range(n_poses-1): + start_pose = c2ws[i] + end_pose = c2ws[(i + 1) % n_poses] + interpolated_path = interpolate_poses_spline(torch.stack([start_pose, end_pose])[:, :3, :].cpu().numpy(), n_inserts).to(device) + interpolated_path = interpolated_path[:-1] + interpolated_poses.append(interpolated_path) + + interpolated_poses.append(c2ws[-1:]) + full_path = torch.cat(interpolated_poses, dim=0) + + return full_path + +def generate_traj_interp(c2ws,H,W,fs,c,ns,device): + + c2ws = interp_traj(c2ws,n_inserts= ns,device=device) + num_views = c2ws.shape[0] + R, T = c2ws[:,:3, :3], c2ws[:,:3, 3:] + R = torch.stack([-R[:,:, 0], -R[:,:, 1], R[:,:, 2]], 2) # from RDF to LUF for Rotation + new_c2w = torch.cat([R, T], 2) + w2c = torch.linalg.inv(torch.cat((new_c2w, torch.Tensor([[[0,0,0,1]]]).to(device).repeat(new_c2w.shape[0],1,1)),1)) + R_new, T_new = w2c[:,:3, :3].permute(0,2,1), w2c[:,:3, 3] # convert R to row-major matrix + image_size = ((H, W),) # (h, w) + + fs = interpolate_sequence(fs,ns-2,device=device) + c = interpolate_sequence(c,ns-2,device=device) + cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False, image_size=image_size, R=R_new, T=T_new, device=device) + + return cameras, num_views + +def generate_traj_specified(c2ws_anchor,H,W,fs,c,theta, phi,d_r,d_x,d_y,frame,device): + # Initialize a camera. + """ + The camera coordinate sysmte in COLMAP is right-down-forward + Pytorch3D is left-up-forward + """ + + thetas = np.linspace(0,theta,frame) + phis = np.linspace(0,phi,frame) + if isinstance(c2ws_anchor,torch.Tensor): + rs = np.linspace(0,d_r*c2ws_anchor[0,2,3].cpu(),frame) + xs = np.linspace(0,d_x.cpu(),frame) + ys = np.linspace(0,d_y.cpu(),frame) + else: + rs = np.linspace(0,d_r*c2ws_anchor[0,2,3],frame) + xs = np.linspace(0,d_x,frame) + ys = np.linspace(0,d_y,frame) + c2ws_list = [] + for th, ph, r, x, y in zip(thetas,phis,rs, xs, ys): + c2w_new = sphere2pose(c2ws_anchor, np.float32(th), np.float32(ph), np.float32(r), device, np.float32(x),np.float32(y)) + c2ws_list.append(c2w_new) + c2ws = torch.cat(c2ws_list,dim=0) + num_views = c2ws.shape[0] + + R, T = c2ws[:,:3, :3], c2ws[:,:3, 3:] + ## 将dust3r坐标系转成pytorch3d坐标系 + R = torch.stack([-R[:,:, 0], -R[:,:, 1], R[:,:, 2]], 2) # from RDF to LUF for Rotation + new_c2w = torch.cat([R, T], 2) + w2c = torch.linalg.inv(torch.cat((new_c2w, torch.Tensor([[[0,0,0,1]]]).to(device).repeat(new_c2w.shape[0],1,1)),1)) + R_new, T_new = w2c[:,:3, :3].permute(0,2,1), w2c[:,:3, 3] # convert R to row-major matrix + image_size = ((H, W),) # (h, w) + cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False, image_size=image_size, R=R_new, T=T_new, device=device) + return cameras,num_views + +def generate_traj_txt(c2ws_anchor,H,W,fs,c,phi, theta, r,frame,device,viz_traj=False, save_dir = None): + # Initialize a camera. + """ + The camera coordinate sysmte in COLMAP is right-down-forward + Pytorch3D is left-up-forward + """ + + if len(phi)>3: + phis = txt_interpolation(phi,frame,mode='smooth') + phis[0] = phi[0] + phis[-1] = phi[-1] + else: + phis = txt_interpolation(phi,frame,mode='linear') + + if len(theta)>3: + thetas = txt_interpolation(theta,frame,mode='smooth') + thetas[0] = theta[0] + thetas[-1] = theta[-1] + else: + thetas = txt_interpolation(theta,frame,mode='linear') + + if len(r) >3: + rs = txt_interpolation(r,frame,mode='smooth') + rs[0] = r[0] + rs[-1] = r[-1] + else: + rs = txt_interpolation(r,frame,mode='linear') + rs = rs*c2ws_anchor[0,2,3].cpu().numpy() + + c2ws_list = [] + for th, ph, r in zip(thetas,phis,rs): + c2w_new = sphere2pose(c2ws_anchor, np.float32(th), np.float32(ph), np.float32(r), device) + c2ws_list.append(c2w_new) + c2ws = torch.cat(c2ws_list,dim=0) + + if viz_traj: + poses = c2ws.cpu().numpy() + # visualizer(poses, os.path.join(save_dir,'viz_traj.png')) + frames = [visualizer_frame(poses, i) for i in range(len(poses))] + save_video(np.array(frames)/255.,os.path.join(save_dir,'viz_traj.mp4')) + + num_views = c2ws.shape[0] + + R, T = c2ws[:,:3, :3], c2ws[:,:3, 3:] + ## 将dust3r坐标系转成pytorch3d坐标系 + R = torch.stack([-R[:,:, 0], -R[:,:, 1], R[:,:, 2]], 2) # from RDF to LUF for Rotation + new_c2w = torch.cat([R, T], 2) + w2c = torch.linalg.inv(torch.cat((new_c2w, torch.Tensor([[[0,0,0,1]]]).to(device).repeat(new_c2w.shape[0],1,1)),1)) + R_new, T_new = w2c[:,:3, :3].permute(0,2,1), w2c[:,:3, 3] # convert R to row-major matrix + image_size = ((H, W),) # (h, w) + cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False, image_size=image_size, R=R_new, T=T_new, device=device) + return cameras,num_views + +def setup_renderer(cameras, image_size): + # Define the settings for rasterization and shading. + raster_settings = PointsRasterizationSettings( + image_size=image_size, + radius = 0.01, + points_per_pixel = 10, + bin_size = 0 + ) + + renderer = PointsRenderer( + rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings), + compositor=AlphaCompositor() + ) + + render_setup = {'cameras': cameras, 'raster_settings': raster_settings, 'renderer': renderer} + + return render_setup + +def interpolate_sequence(sequence, k,device): + + N, M = sequence.size() + weights = torch.linspace(0, 1, k+1).view(1, -1, 1).to(device) + left_values = sequence[:-1].unsqueeze(1).repeat(1, k+1, 1) + right_values = sequence[1:].unsqueeze(1).repeat(1, k+1, 1) + new_sequence = torch.einsum("ijk,ijl->ijl", (1 - weights), left_values) + torch.einsum("ijk,ijl->ijl", weights, right_values) + new_sequence = new_sequence.reshape(-1, M) + new_sequence = torch.cat([new_sequence, sequence[-1].view(1, -1)], dim=0) + return new_sequence + +def focus_point_fn(c2ws: torch.Tensor) -> torch.Tensor: + """Calculate nearest point to all focal axes in camera-to-world matrices.""" + # Extract camera directions and origins from c2ws + directions, origins = c2ws[:, :3, 2:3], c2ws[:, :3, 3:4] + m = torch.eye(3).to(c2ws.device) - directions * torch.transpose(directions, 1, 2) + mt_m = torch.transpose(m, 1, 2) @ m + focus_pt = torch.inverse(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] + return focus_pt + +def generate_camera_path(c2ws: torch.Tensor, n_inserts: int = 15, device='cuda') -> torch.Tensor: + n_poses = c2ws.shape[0] + interpolated_poses = [] + + for i in range(n_poses-1): + start_pose = c2ws[i] + end_pose = c2ws[(i + 1) % n_poses] + focus_point = focus_point_fn(torch.stack([start_pose,end_pose])) + interpolated_path = interpolate_poses(start_pose, end_pose, focus_point, n_inserts, device) + + # Exclude the last pose (end_pose) for all pairs + interpolated_path = interpolated_path[:-1] + + interpolated_poses.append(interpolated_path) + # Concatenate all the interpolated paths + interpolated_poses.append(c2ws[-1:]) + full_path = torch.cat(interpolated_poses, dim=0) + return full_path + +def interpolate_poses(start_pose: torch.Tensor, end_pose: torch.Tensor, focus_point: torch.Tensor, n_inserts: int = 15, device='cuda') -> torch.Tensor: + dtype = start_pose.dtype + start_distance = torch.sqrt((start_pose[0, 3] - focus_point[0])**2 + (start_pose[1, 3] - focus_point[1])**2 + (start_pose[2, 3] - focus_point[2])**2) + end_distance = torch.sqrt((end_pose[0, 3] - focus_point[0])**2 + (end_pose[1, 3] - focus_point[1])**2 + (end_pose[2, 3] - focus_point[2])**2) + start_rot = R.from_matrix(start_pose[:3, :3].cpu().numpy()) + end_rot = R.from_matrix(end_pose[:3, :3].cpu().numpy()) + slerp_obj = Slerp([0, 1], R.from_quat([start_rot.as_quat(), end_rot.as_quat()])) + + inserted_c2ws = [] + + for t in torch.linspace(0., 1., n_inserts + 2, dtype=dtype): # Exclude the first and last point + interpolated_rot = slerp_obj(t).as_matrix() + interpolated_translation = (1 - t) * start_pose[:3, 3] + t * end_pose[:3, 3] + interpolated_distance = (1 - t) * start_distance + t * end_distance + direction = (interpolated_translation - focus_point) / torch.norm(interpolated_translation - focus_point) + interpolated_translation = focus_point + direction * interpolated_distance + + inserted_pose = torch.eye(4, dtype=dtype).to(device) + inserted_pose[:3, :3] = torch.from_numpy(interpolated_rot).to(device) + inserted_pose[:3, 3] = interpolated_translation + inserted_c2ws.append(inserted_pose) + + path = torch.stack(inserted_c2ws) + return path + + + +def inv(mat): + """ Invert a torch or numpy matrix + """ + if isinstance(mat, torch.Tensor): + return torch.linalg.inv(mat) + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f'bad matrix type = {type(mat)}') + +def save_pointcloud_with_normals(imgs, pts3d, msk, save_path, mask_pc, reduce_pc): + pc = get_pc(imgs, pts3d, msk,mask_pc,reduce_pc) # Assuming get_pc is defined elsewhere and returns a trimesh point cloud + + # Define a default normal, e.g., [0, 1, 0] + default_normal = [0, 1, 0] + + # Prepare vertices, colors, and normals for saving + vertices = pc.vertices + colors = pc.colors + normals = np.tile(default_normal, (vertices.shape[0], 1)) + + # Construct the header of the PLY file + header = """ply +format ascii 1.0 +element vertex {} +property float x +property float y +property float z +property uchar red +property uchar green +property uchar blue +property float nx +property float ny +property float nz +end_header +""".format(len(vertices)) + + # Write the PLY file + with open(save_path, 'w') as ply_file: + ply_file.write(header) + for vertex, color, normal in zip(vertices, colors, normals): + ply_file.write('{} {} {} {} {} {} {} {} {}\n'.format( + vertex[0], vertex[1], vertex[2], + int(color[0]), int(color[1]), int(color[2]), + normal[0], normal[1], normal[2] + )) + + +def get_pc(imgs, pts3d, mask, mask_pc=False, reduce_pc=False): + imgs = to_numpy(imgs) + pts3d = to_numpy(pts3d) + mask = to_numpy(mask) + + if mask_pc: + pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]) + col = np.concatenate([p[m] for p, m in zip(imgs, mask)]) + else: + pts = np.concatenate([p for p in pts3d]) + col = np.concatenate([p for p in imgs]) + + if reduce_pc: + pts = pts.reshape(-1, 3)[::3] + col = col.reshape(-1, 3)[::3] + else: + pts = pts.reshape(-1, 3) + col = col.reshape(-1, 3) + + #mock normals: + normals = np.tile([0, 1, 0], (pts.shape[0], 1)) + + pct = trimesh.PointCloud(pts, colors=col) + # debug + # pct.export('output.ply') + # print('exporting output.ply') + pct.vertices_normal = normals # Manually add normals to the point cloud + + return pct#, pts + +def world_to_kth(poses, k): + # 将世界坐标系转到和第k个pose的相机坐标系一致 + kth_pose = poses[k] + inv_kth_pose = torch.inverse(kth_pose) + new_poses = torch.bmm(inv_kth_pose.unsqueeze(0).expand_as(poses), poses) + return new_poses + +def world_point_to_kth(poses, points, k, device): + # 将世界坐标系转到和第k个pose的相机坐标系一致,同时处理点云 + kth_pose = poses[k] + inv_kth_pose = torch.inverse(kth_pose) + # 给所有pose左成kth_w2c,将其都变到kth_pose的camera coordinate下 + new_poses = torch.bmm(inv_kth_pose.unsqueeze(0).expand_as(poses), poses) + N, W, H, _ = points.shape + points = points.view(N, W * H, 3) + homogeneous_points = torch.cat([points, torch.ones(N, W*H, 1).to(device)], dim=-1) + new_points = inv_kth_pose.unsqueeze(0).expand(N, -1, -1).unsqueeze(1)@ homogeneous_points.unsqueeze(-1) + new_points = new_points.squeeze(-1)[...,:3].view(N, W, H, _) + + return new_poses, new_points + + +def world_point_to_obj(poses, points, k, r, elevation, device): + ## 作用:将世界坐标系转到object的中心 + + ## 先将世界坐标系转到指定相机 + poses, points = world_point_to_kth(poses, points, k, device) + + ## 定义目标坐标系位姿, 原点位于object中心(远世界坐标系[0,0,r]),Y轴向上, Z轴垂直屏幕向外, X轴向右 + elevation_rad = torch.deg2rad(torch.tensor(180-elevation)).to(device) + sin_value_x = torch.sin(elevation_rad) + cos_value_x = torch.cos(elevation_rad) + R = torch.tensor([[1, 0, 0,], + [0, cos_value_x, sin_value_x], + [0, -sin_value_x, cos_value_x]]).to(device) + + t = torch.tensor([0, 0, r]).to(device) + pose_obj = torch.eye(4).to(device) + pose_obj[:3, :3] = R + pose_obj[:3, 3] = t + + ## 给所有点和pose乘以目标坐标系的逆(w2c),将它们变换到目标坐标系下 + inv_obj_pose = torch.inverse(pose_obj) + new_poses = torch.bmm(inv_obj_pose.unsqueeze(0).expand_as(poses), poses) + + N, W, H, _ = points.shape + points = points.view(N, W * H, 3) + homogeneous_points = torch.cat([points, torch.ones(N, W*H, 1).to(device)], dim=-1) + new_points = inv_obj_pose.unsqueeze(0).expand(N, -1, -1).unsqueeze(1)@ homogeneous_points.unsqueeze(-1) + new_points = new_points.squeeze(-1)[...,:3].view(N, W, H, _) + + return new_poses, new_points + +def world_point_to_obj_nuscenes(poses, points, c2ws, k, r, elevation, device): + """ + Transform camera poses and points to object-centric coordinates while maintaining RDF system. + + Args: + poses: [N, 4, 4] Camera poses in RDF coordinate system(nuscenes coordinate system) + points: [M, H, W, 3] 3D points or None, where M may not equal N(dust coordinate system) + k: Index of reference camera + r: Radius for centering + elevation: Camera elevation angle in degrees + device: torch device + + Returns: + new_poses: [N, 4, 4] Transformed camera poses + new_points: [M, H, W, 3] Transformed points (if points provided) + """ + def calculate_mean_scale(poses2kth_pose, dust2kth): + if len(poses2kth_pose) <= 1 or len(dust2kth) <= 1: + return 1. + nuscenes_trans = poses2kth_pose[1:, :3, 3] + dust3r_trans = dust2kth[1:, :3, 3] + scales = dust3r_trans.norm(dim=-1) / (nuscenes_trans.norm(dim=-1) + 1e-6) + return scales.mean() + + # 1. First transform everything to kth camera coordinate system + kth_pose = poses[k].to(torch.float64).to(device) # Reference pose + inv_kth_pose = torch.inverse(kth_pose).to(torch.float64) + kth_pose_dust = c2ws[k].to(torch.float64).to(device) + inv_kth_pose_dust = torch.inverse(kth_pose_dust).to(torch.float64) + # kth_pose2dust = kth_pose_dust @ inv_kth_pose + + + # Transform all poses to reference camera space + poses2kth_pose = torch.bmm( + inv_kth_pose.unsqueeze(0).expand_as(poses.to(torch.float64)), + poses.to(torch.float64) + ) + # dust all poses to reference camera space + dust2kth = torch.bmm( + inv_kth_pose_dust.unsqueeze(0).expand_as(c2ws.to(torch.float64)), + c2ws.to(torch.float64) + ) + # scale factor of nuscenes to dust3r--dust3r / nuscenes + scale_factor = calculate_mean_scale(poses2kth_pose, dust2kth) + poses2kth_pose_new = poses2kth_pose + poses2kth_pose_new[:, :3, 3] *= scale_factor + + # 2. Create object-centric transformation + # Convert elevation to radians and create rotation matrix + elevation_rad = torch.deg2rad(torch.tensor(180-elevation)).to(device) + sin_value = torch.sin(elevation_rad) + cos_value = torch.cos(elevation_rad) + + # Rotation matrix for RDF coordinate system + # Rotates around right (x) axis to achieve desired elevation + R = torch.tensor([ + [1, 0, 0], + [0, cos_value, -sin_value], # Note: signs adjusted for RDF + [0, sin_value, cos_value] + ], device=device) + + # Translation to center object at distance r along forward axis + t = torch.tensor([0, 0, r], device=device) + + # Combine into transformation matrix + pose_obj = torch.eye(4, device=device) + pose_obj[:3, :3] = R + pose_obj[:3, 3] = t + + # 3. Apply object-centric transformation + inv_obj_pose = torch.inverse(pose_obj).to(torch.float64) + new_poses = torch.bmm( + inv_obj_pose.unsqueeze(0).expand_as(poses2kth_pose_new), + poses2kth_pose_new + ) + + # 4. Transform points if provided + if points is not None: + M, H, W, _ = points.shape + points_reshaped = points.view(M, H * W, 3) + + # Add homogeneous coordinate + homogeneous_points = torch.cat([ + points_reshaped, + torch.ones(M, H*W, 1, device=device) + ], dim=-1) + + # First transform to reference camera space + points_cam = torch.bmm( + inv_kth_pose_dust.unsqueeze(0).expand(M, -1, -1), + homogeneous_points.transpose(1, 2).to(torch.float64) + ).transpose(1, 2) + + # Then transform to object space + new_points = torch.bmm( + inv_obj_pose.unsqueeze(0).expand(M, -1, -1), + points_cam.transpose(1, 2).to(torch.float64) # Transpose to match batch matrix multiply dimensions + ).transpose(1, 2) # Transpose back to original shape + + # Remove homogeneous coordinate and reshape + new_points = new_points[..., :3].view(M, H, W, 3) + + return new_poses.to(torch.float32), new_points.to(torch.float32) + + return new_poses.to(torch.float32) + +def txt_interpolation(input_list,n,mode = 'smooth'): + x = np.linspace(0, 1, len(input_list)) + if mode == 'smooth': + f = UnivariateSpline(x, input_list, k=3) + elif mode == 'linear': + f = interp1d(x, input_list) + else: + raise KeyError(f"Invalid txt interpolation mode: {mode}") + xnew = np.linspace(0, 1, n) + ynew = f(xnew) + return ynew + +def visualizer(camera_poses, save_path="out.png"): + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + + colors = ["blue" for _ in camera_poses] + for pose, color in zip(camera_poses, colors): + + camera_positions = pose[:3, 3] + ax.scatter( + camera_positions[0], + camera_positions[1], + camera_positions[2], + c=color, + marker="o", + ) + + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") + ax.set_title("Camera trajectory") + # ax.view_init(90+30, -90) + plt.savefig(save_path) + plt.close() + +def visualizer_frame(camera_poses, highlight_index): + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + # 获取camera_positions[2]的最大值和最小值 + z_values = [pose[:3, 3][2] for pose in camera_poses] + z_min, z_max = min(z_values), max(z_values) + + # 创建一个颜色映射对象 + cmap = mcolors.LinearSegmentedColormap.from_list("mycmap", ["#00008B", "#ADD8E6"]) + # cmap = plt.get_cmap("coolwarm") + norm = mcolors.Normalize(vmin=z_min, vmax=z_max) + sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) + + for i, pose in enumerate(camera_poses): + camera_positions = pose[:3, 3] + color = "blue" if i == highlight_index else "blue" + size = 100 if i == highlight_index else 25 + color = sm.to_rgba(camera_positions[2]) # 根据camera_positions[2]的值映射颜色 + ax.scatter( + camera_positions[0], + camera_positions[1], + camera_positions[2], + c=color, + marker="o", + s=size, + ) + + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") + # ax.set_title("Camera trajectory") + ax.view_init(90+30, -90) + + plt.ylim(-0.1,0.2) + fig.canvas.draw() + width, height = fig.canvas.get_width_height() + + img = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8').reshape(height, width, 3) + # new_width = int(width * 0.6) + # start_x = (width - new_width) // 2 + new_width // 5 + # end_x = start_x + new_width + # img = img[:, start_x:end_x, :] + + + plt.close() + + return img + + +def center_crop_image(input_image): + + height = 576 + width = 1024 + _,_,h,w = input_image.shape + h_ratio = h / height + w_ratio = w / width + + if h_ratio > w_ratio: + h = int(h / w_ratio) + if h < height: + h = height + input_image = Resize((h, width))(input_image) + + else: + w = int(w / h_ratio) + if w < width: + w = width + input_image = Resize((height, w))(input_image) + + transformer = Compose([ + # Resize(width), + CenterCrop((height, width)), + ]) + + input_image = transformer(input_image) + return input_image + diff --git a/monst3r/render.py b/monst3r/render.py new file mode 100644 index 0000000..0e0c9b4 --- /dev/null +++ b/monst3r/render.py @@ -0,0 +1,505 @@ +# -------------------------------------------------------- +# gradio demo +# -------------------------------------------------------- + +import argparse +from fileinput import fileno +import math +from struct import calcsize +from altair import Theta +import gradio +import os +from sympy import ask +import torch +import numpy as np +import tempfile +import functools +import copy + +from dust3r.inference import inference +from dust3r.model import AsymmetricCroCo3DStereo +from dust3r.image_pairs import make_pairs +from dust3r.utils.image import load_images, rgb, enlarge_seg_masks +from dust3r.utils.device import to_numpy +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode +from dust3r.utils.viz_demo import convert_scene_output_to_glb, get_dynamic_mask_from_pairviewer +import matplotlib.pyplot as pl +import cv2 + +from pvd_utils import world_point_to_obj, sphere2pose, generate_traj_specified, setup_renderer, save_video, save_frames + +pl.ion() +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + +## render imports +from pytorch3d.renderer import PerspectiveCameras +from pytorch3d.structures import Pointclouds +from PIL import Image +import torchvision +import torch.nn.functional as F +from tqdm import tqdm +import time + +def render_pcd(pts3d,imgs,masks,views,renderer,device='cpu',nbv=False): + # Placeholder for rendering point cloud + imgs = to_numpy(imgs) + pts3d = to_numpy(pts3d) + + if masks == None: + pts = torch.from_numpy(np.concatenate([p for p in pts3d])).view(-1, 3).to(device) + col = torch.from_numpy(np.concatenate([p for p in imgs])).view(-1, 3).to(device) + else: + # masks = to_numpy(masks) + pts = torch.from_numpy(np.concatenate([p[m] for p, m in zip(pts3d, masks)])).to(device) + col = torch.from_numpy(np.concatenate([p[m] for p, m in zip(imgs, masks)])).to(device) + + point_cloud = Pointclouds(points=[pts], features=[col]).extend(views) + images = renderer(point_cloud) + + if nbv: + color_mask = torch.ones(col.shape).to(device) + point_cloud_mask = Pointclouds(points=[pts],features=[color_mask]).extend(views) + view_masks = renderer(point_cloud_mask) + else: + view_masks = None + + return images, view_masks + +def run_render(pcd, imgs,masks, H, W, camera_traj,num_views,device='cpu', nbv=False): + # Placeholder for running the rendering process + render_setup = setup_renderer(camera_traj, image_size=(H,W)) + renderer = render_setup['renderer'] + render_results, viewmask = render_pcd(pcd, imgs, masks, num_views,renderer,device=device,nbv=nbv) + return render_results, viewmask + +def generate_traj_from_single_c2w(c2ws_anchor,H,W,fs,c,theta=0., phi=0.,d_r=0.,d_x=0.,d_y=0.,frame=1,device='cuda'): + # Placeholder for generating trajectory from a single camera to world transformation + """ + The camera coordinate sysmte in COLMAP is right-down-forward + Pytorch3D is left-up-forward + """ + c2w_new = sphere2pose(c2ws_anchor, np.float32(theta), np.float32(phi), np.float32(d_r), device, np.float32(d_x),np.float32(d_y)) + c2ws = c2w_new + num_views = c2ws.shape[0] + + R, T = c2ws[:,:3, :3], c2ws[:,:3, 3:] + ## 将dust3r坐标系转成pytorch3d坐标系 + R = torch.stack([-R[:,:, 0], -R[:,:, 1], R[:,:, 2]], 2) # from RDF to LUF for Rotation + new_c2w = torch.cat([R, T], 2) + w2c = torch.linalg.inv(torch.cat((new_c2w, torch.Tensor([[[0,0,0,1]]]).to(device).repeat(new_c2w.shape[0],1,1)),1)) + R_new, T_new = w2c[:,:3, :3].permute(0,2,1), w2c[:,:3, 3] # convert R to row-major matrix + image_size = ((H, W),) # (h, w) + cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False, image_size=image_size, R=R_new, T=T_new, device=device) + return cameras,num_views + + +def generate_traj_from_c2ws(c2ws_anchor,H,W,fs,c,frame,device): + # Placeholder for generating trajectory from multiple camera to world transformations + # Initialize a camera. + """ + The camera coordinate sysmte in COLMAP is right-down-forward + Pytorch3D is left-up-forward + """ + theta = 0 + phi = 0 + r = 1 + x = 0 + y = 0 + + c2ws_list = [] + for i in range(frame): + c2w_new = sphere2pose(c2ws_anchor[i:i+1], np.float32(theta), np.float32(phi), np.float32(r), device, np.float32(x),np.float32(y)) + c2ws_list.append(c2w_new) + c2ws = torch.cat(c2ws_list,dim=0) + num_views = c2ws.shape[0] + + R, T = c2ws[:,:3, :3], c2ws[:,:3, 3:] + ## 将dust3r坐标系转成pytorch3d坐标系 + R = torch.stack([-R[:,:, 0], -R[:,:, 1], R[:,:, 2]], 2) # from RDF to LUF for Rotation + new_c2w = torch.cat([R, T], 2) + w2c = torch.linalg.inv(torch.cat((new_c2w, torch.Tensor([[[0,0,0,1]]]).to(device).repeat(new_c2w.shape[0],1,1)),1)) + R_new, T_new = w2c[:,:3, :3].permute(0,2,1), w2c[:,:3, 3] # convert R to row-major matrix + image_size = ((H, W),) # (h, w) + cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False, image_size=image_size, R=R_new, T=T_new, device=device) + return cameras,num_views + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser_url = parser.add_mutually_exclusive_group() + parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size") + parser.add_argument("--weights", type=str, help="path to the model weights", default='checkpoints/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt.pth') + parser.add_argument("--model_name", type=str, default='Junyi42/MonST3R_PO-TA-S-W_ViTLarge_BaseDecoder_512_dpt', help="model name") + parser.add_argument("--device", type=str, default='cuda', help="pytorch device") + parser.add_argument("--output_dir", type=str, default='./demo_tmp', help="value for tempfile.tempdir") + parser.add_argument("--silent", action='store_true', default=False, + help="silence logs") + parser.add_argument("--input_dir", type=str, help="Path to input images directory", default=None) + parser.add_argument("--seq_name", type=str, help="Sequence name for evaluation", default='NULL') + parser.add_argument('--use_gt_davis_masks', action='store_true', default=False, help='Use ground truth masks for DAVIS') + parser.add_argument('--not_batchify', action='store_true', default=False, help='Use non batchify mode for global optimization') + parser.add_argument('--real_time', action='store_true', default=False, help='Realtime mode') + parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference') + + parser.add_argument('--fps', type=int, default=0, help='FPS for video processing') + parser.add_argument('--num_frames', type=int, default=200, help='Maximum number of frames for video processing') + + # render options + parser.add_argument('--total_sections', type=int, default=1, help='Total sections for rendering') + parser.add_argument('--section', type=int, default=0, help='Section number for rendering') + + return parser + +def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False, + clean_depth=False, transparent_cams=False, cam_size=0.05, show_cam=True, save_name=None, thr_for_init_conf=True): + # Placeholder for getting 3D model from scene + """ + extract 3D_model (glb file) from a reconstructed scene + """ + if scene is None: + return None + # post processes + if clean_depth: + scene = scene.clean_pointcloud() + if mask_sky: + scene = scene.mask_sky() + + # get optimized values from scene + rgbimg = scene.imgs + focals = scene.get_focals().cpu() + cams2world = scene.get_im_poses().cpu() + # 3D pointcloud from depthmap, poses and intrinsics + pts3d = to_numpy(scene.get_pts3d(raw_pts=True)) + scene.min_conf_thr = min_conf_thr + scene.thr_for_init_conf = thr_for_init_conf + msk = to_numpy(scene.get_masks()) + cmap = pl.get_cmap('viridis') + cam_color = [cmap(i/len(rgbimg))[:3] for i in range(len(rgbimg))] + cam_color = [(255*c[0], 255*c[1], 255*c[2]) for c in cam_color] + + return convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud, + transparent_cams=transparent_cams, cam_size=cam_size, show_cam=show_cam, silent=silent, save_name=save_name, + cam_color=cam_color) + +def render_2D_images_from_scene(outdir, scene, min_conf_thr=3, mask_sky=False, + clean_depth=False, thr_for_init_conf=True): + # Placeholder for rendering 2D images from the scene + """ + render 2D images from a reconstructed scene + """ + if scene is None: + return None + # post processes + if clean_depth: + scene = scene.clean_pointcloud() + if mask_sky: + scene = scene.mask_sky() + + # get optimized values from scene + rgbimg = scene.imgs + # scene.min_conf_thr = min_conf_thr + # scene.thr_for_init_conf = thr_for_init_conf + # msk = to_numpy(scene.get_masks()) + msk = None + + # DEBUG, render images + elevation = 5.0 + center_scale = 1.0 + d_x = torch.tensor(0.0) + d_y = torch.tensor(0.0) + d_r = torch.tensor(0.0) + d_theta = torch.tensor(0.0) + d_phi = torch.tensor(0.0) + dpt_trd = torch.tensor(1.0) + device = torch.device('cuda') + H, W = rgbimg[0].shape[:2] + + c2ws = scene.get_im_poses().detach() + pcd = [i.detach() for i in scene.get_pts3d(clip_thred=dpt_trd, raw_pts=True)] # a list of points of size whc + focals = scene.get_focals().detach() + principal_points = scene.get_principal_points().detach() + depth = to_numpy(scene.get_depthmaps()) + depth_avg = depth[0][H//2,W//2] + radius = depth_avg*center_scale + + c2ws,pcd = world_point_to_obj(poses=c2ws, points=torch.stack(pcd), k=0, r=radius, elevation=elevation, device=device) + + # Apply the fixed view transformation to all camera poses + camera_poses = [] + for i in range(len(c2ws)): + # Generate single transformed pose + single_pose, _ = generate_traj_from_single_c2w( + c2ws[i:i+1], # Take one pose at a time + H, W, + focals[0:1], # Use first frame's focal length for all + principal_points[0:1], # Use first frame's principal point for all + frame=1, # Only generate one frame per pose + device=device + ) + camera_poses.append(single_pose) + + # Combine all transformed cameras + camera_traj = camera_poses[0] + for cam in camera_poses[1:]: + camera_traj.R = torch.cat([camera_traj.R, cam.R]) + camera_traj.T = torch.cat([camera_traj.T, cam.T]) + num_views = camera_traj.R.shape[0] + + # use first frame point cloud and image + render_results, viewmask = run_render(pcd[0:1], rgbimg[0:1], msk, H, W, camera_traj,num_views, device=device, nbv=False) + + # put ref image into the first position + render_results[0] = torch.from_numpy(rgbimg[0]) + + render_results = F.interpolate(render_results.permute(0,3,1,2), size=(576, 1024), mode='bilinear', align_corners=False).permute(0,2,3,1) + + save_frames(render_results, outdir, save_names=[f'render_{i}' for i in range(len(render_results))], folder=outdir) + save_video(render_results, os.path.join(outdir, 'render.mp4')) + + return render_results + +def get_reconstructed_scene(args, outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr, + as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, show_cam, scenegraph_type, winsize, refid, + seq_name, new_model_weights, temporal_smoothing_weight, translation_weight, shared_focal, + flow_loss_weight, flow_loss_start_iter, flow_loss_threshold, use_gt_mask, fps, num_frames): + # Placeholder for getting reconstructed scene + translation_weight = float(translation_weight) + if new_model_weights != args.weights: + model = AsymmetricCroCo3DStereo.from_pretrained(new_model_weights).to(device) + model.eval() + if seq_name != "NULL": + dynamic_mask_path = f'data/davis/DAVIS/masked_images/480p/{seq_name}' + else: + dynamic_mask_path = None + imgs = load_images(filelist, size=image_size, verbose=not silent, dynamic_mask_root=dynamic_mask_path, fps=fps, num_frames=num_frames) + if len(imgs) == 1: + imgs = [imgs[0], copy.deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + if scenegraph_type == "swin" or scenegraph_type == "swinstride" or scenegraph_type == "swin2stride": + scenegraph_type = scenegraph_type + "-" + str(winsize) + "-noncyclic" + elif scenegraph_type == "oneref": + scenegraph_type = scenegraph_type + "-" + str(refid) + + pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True) + output = inference(pairs, model, device, batch_size=args.batch_size, verbose=not silent) + if len(imgs) > 2: + mode = GlobalAlignerMode.PointCloudOptimizer + scene = global_aligner(output, device=device, mode=mode, verbose=not silent, shared_focal = shared_focal, temporal_smoothing_weight=temporal_smoothing_weight, translation_weight=translation_weight, + flow_loss_weight=flow_loss_weight, flow_loss_start_epoch=flow_loss_start_iter, flow_loss_thre=flow_loss_threshold, use_self_mask=not use_gt_mask, + num_total_iter=niter, empty_cache= len(filelist) > 72, batchify=not args.not_batchify) + else: + mode = GlobalAlignerMode.PairViewer + scene = global_aligner(output, device=device, mode=mode, verbose=not silent) + lr = 0.01 + + if mode == GlobalAlignerMode.PointCloudOptimizer: + loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr) + + save_folder = f'{args.output_dir}/{seq_name}' #default is 'demo_tmp/NULL' + os.makedirs(save_folder, exist_ok=True) + + render_results = render_2D_images_from_scene(save_folder, scene, min_conf_thr, mask_sky, + clean_depth, thr_for_init_conf=True) + + # save_folder = f'{save_folder}/reconstruction' #default is 'demo_tmp/NULL' + # poses = scene.save_tum_poses(f'{save_folder}/pred_traj.txt') + # K = scene.save_intrinsics(f'{save_folder}/pred_intrinsics.txt') + # depth_maps = scene.save_depth_maps(save_folder) + # dynamic_masks = scene.save_dynamic_masks(save_folder) + # conf = scene.save_conf_maps(save_folder) + # init_conf = scene.save_init_conf_maps(save_folder) + # rgbs = scene.save_rgb_imgs(save_folder) + # enlarge_seg_masks(save_folder, kernel_size=5 if use_gt_mask else 3) + + # # also return rgb, depth and confidence imgs + # # depth is normalized with the max value for all images + # # we apply the jet colormap on the confidence maps + # rgbimg = scene.imgs + # depths = to_numpy(scene.get_depthmaps()) + # confs = to_numpy([c for c in scene.im_conf]) + # init_confs = to_numpy([c for c in scene.init_conf_maps]) + # cmap = pl.get_cmap('jet') + # depths_max = max([d.max() for d in depths]) + # depths = [cmap(d/depths_max) for d in depths] + # confs_max = max([d.max() for d in confs]) + # confs = [cmap(d/confs_max) for d in confs] + # init_confs_max = max([d.max() for d in init_confs]) + # init_confs = [cmap(d/init_confs_max) for d in init_confs] + + # imgs = [] + # for i in range(len(rgbimg)): + # imgs.append(rgbimg[i]) + # imgs.append(rgb(depths[i])) + # imgs.append(rgb(confs[i])) + # imgs.append(rgb(init_confs[i])) + + # # if two images, and the shape is same, we can compute the dynamic mask + # if len(rgbimg) == 2 and rgbimg[0].shape == rgbimg[1].shape: + # motion_mask_thre = 0.35 + # error_map = get_dynamic_mask_from_pairviewer(scene, both_directions=True, output_dir=args.output_dir, motion_mask_thre=motion_mask_thre) + # # imgs.append(rgb(error_map)) + # # apply threshold on the error map + # normalized_error_map = (error_map - error_map.min()) / (error_map.max() - error_map.min()) + # error_map_max = normalized_error_map.max() + # error_map = cmap(normalized_error_map/error_map_max) + # imgs.append(rgb(error_map)) + # binary_error_map = (normalized_error_map > motion_mask_thre).astype(np.uint8) + # imgs.append(rgb(binary_error_map*255)) + + # used_time = time.time() - start_time + + del scene + torch.cuda.empty_cache() + + return render_results + +def get_reconstructed_scene_realtime(args, model, device, silent, image_size, filelist, scenegraph_type, refid, seq_name, fps, num_frames): + # Placeholder for getting reconstructed scene in real-time + model.eval() + imgs = load_images(filelist, size=image_size, verbose=not silent, fps=fps, num_frames=num_frames) + if len(imgs) == 1: + imgs = [imgs[0], copy.deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + + if scenegraph_type == "oneref": + scenegraph_type = scenegraph_type + "-" + str(refid) + elif scenegraph_type == "oneref_mid": + scenegraph_type = "oneref-" + str(len(imgs) // 2) + else: + raise ValueError(f"Unknown scenegraph type for realtime mode: {scenegraph_type}") + + pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=False) + output = inference(pairs, model, device, batch_size=args.batch_size, verbose=not silent) + + save_folder = f'{args.output_dir}/{seq_name}' #default is 'demo_tmp/NULL' + os.makedirs(save_folder, exist_ok=True) + + + view1, view2, pred1, pred2 = output['view1'], output['view2'], output['pred1'], output['pred2'] + pts1 = pred1['pts3d'].detach().cpu().numpy() + pts2 = pred2['pts3d_in_other_view'].detach().cpu().numpy() + for batch_idx in range(len(view1['img'])): + colors1 = rgb(view1['img'][batch_idx]) + colors2 = rgb(view2['img'][batch_idx]) + xyzrgb1 = np.concatenate([pts1[batch_idx], colors1], axis=-1) #(H, W, 6) + xyzrgb2 = np.concatenate([pts2[batch_idx], colors2], axis=-1) + np.save(save_folder + '/pts3d1_p' + str(batch_idx) + '.npy', xyzrgb1) + np.save(save_folder + '/pts3d2_p' + str(batch_idx) + '.npy', xyzrgb2) + + conf1 = pred1['conf'][batch_idx].detach().cpu().numpy() + conf2 = pred2['conf'][batch_idx].detach().cpu().numpy() + np.save(save_folder + '/conf1_p' + str(batch_idx) + '.npy', conf1) + np.save(save_folder + '/conf2_p' + str(batch_idx) + '.npy', conf2) + + # save the imgs of two views + img1_rgb = cv2.cvtColor(colors1 * 255, cv2.COLOR_BGR2RGB) + img2_rgb = cv2.cvtColor(colors2 * 255, cv2.COLOR_BGR2RGB) + cv2.imwrite(save_folder + '/img1_p' + str(batch_idx) + '.png', img1_rgb) + cv2.imwrite(save_folder + '/img2_p' + str(batch_idx) + '.png', img2_rgb) + + return save_folder + + +if __name__ == '__main__': + + parser = get_args_parser() + args = parser.parse_args() + + if args.weights is not None and os.path.exists(args.weights): + weights_path = args.weights + else: + weights_path = args.model_name + + model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device) + + + json_path = '/lpai/volumes/ad-vla-vol-ga/chenantong/datasets/NuScenes/annotation/nuScenes_frontview_caminex_fullframes.json' + data_root = "/lpai/dataset/nuscenes-imgs/0-1-0/nuscenes" + frame_cam_corners_samples = torch.load('../new_frame_cam_corners_samples.pth') + frame_cam_samples_indices = torch.load('../frame_cam_samples_indices.pth') + + # Filter already processed samples + rest_frame_cam_corners_samples = [] + rest_frame_cam_samples_indices = [] + for i in range(len(frame_cam_corners_samples)): + idx = frame_cam_samples_indices[i] + save_dir = os.path.join(args.output_dir, f'sample_{idx:06d}') + if os.path.exists(save_dir) and len(os.listdir(save_dir)) >= 25: + print(f'skip {save_dir}') + continue + rest_frame_cam_corners_samples.append(frame_cam_corners_samples[i]) + rest_frame_cam_samples_indices.append(idx) + + # Get section of data to process + def split_and_get_nth(data, n, splits=16): + if n < 0 or n >= splits: + raise ValueError(f"Invalid split index {n}. Must be between 0 and {splits - 1}.") + split_size = len(data) // splits + remainder = len(data) % splits + splits_indices = [] + start_idx = 0 + for i in range(splits): + extra = 1 if i < remainder else 0 + splits_indices.append((start_idx, start_idx + split_size + extra)) + start_idx += split_size + extra + start, end = splits_indices[n] + return data[start:end] + + split_corners = split_and_get_nth(rest_frame_cam_corners_samples, args.section, args.total_sections) + split_indices = split_and_get_nth(rest_frame_cam_samples_indices, args.section, args.total_sections) + + # Prepare data for processing + scene_data = [(idx, scene, data_root) for idx, scene in zip(split_indices, split_corners)] + + # Process scenes + times = [] + for data in tqdm(scene_data, desc='Processing scenes'): + + # result, used_time = self.process_single_scene(data) + + sample_idx, scene, data_root = data + frame_files = scene['frames'] + assert len(frame_files) == 25 + frame_paths = [os.path.join(data_root, p) for p in frame_files] + instance_corners = scene['instance_corners'] + + if len(instance_corners) > 0: + print('skip dynamic scene for now...') + continue + + start = time.time() + results = get_reconstructed_scene( + args, args.output_dir, model, args.device, args.silent, args.image_size, + filelist=frame_paths, + schedule='linear', + niter=300, + min_conf_thr=1.1, + as_pointcloud=True, + mask_sky=False, + clean_depth=True, + transparent_cams=False, + cam_size=0.05, + show_cam=True, + scenegraph_type='swin', # swin, swinstride, swin2stride, oneref + winsize=3, + refid=0, + seq_name=f'sample_{sample_idx:06d}', + new_model_weights=args.weights, + temporal_smoothing_weight=0.01, + translation_weight='1.0', + shared_focal=True, + flow_loss_weight=0.01, + flow_loss_start_iter=0.1, + flow_loss_threshold=25, + use_gt_mask=args.use_gt_davis_masks, + fps=args.fps, + num_frames=args.num_frames, + ) + used_time = time.time() - start + print(f"Processing completed. Output saved in {save_dir}/{args.seq_name}") + + if results is not None: + times.append(used_time) + avg_time = sum(times) / len(times) + total_time = avg_time * (len(scene_data) - len(times)) + print(f'Split:{args.section}, time used: {times[-1]:.2f}s, estimated left time: {total_time/3600:.2f}h') \ No newline at end of file diff --git a/monst3r/render_monst3r.sh b/monst3r/render_monst3r.sh new file mode 100644 index 0000000..a2d7568 --- /dev/null +++ b/monst3r/render_monst3r.sh @@ -0,0 +1,33 @@ +# CUDA_VISIBLE_DEVICES=2 python render.py \ +# --output_dir outputs/nuscenes_train_plain_stride1window3 \ +# --device 'cuda' \ +# --batch_size 32 \ +# --image_size 512 \ +# --total_sections 1 \ +# --section 0 + +for gpu in 4 5 6 7; do + section=$((gpu * 1)) + CUDA_VISIBLE_DEVICES=$gpu python render.py \ + --output_dir outputs/nuscenes_train_plain_stride1window3 \ + --device 'cuda' \ + --batch_size 32 \ + --image_size 512 \ + --total_sections 16 \ + --section $section > ../logs/log_${gpu}_${section}.out 2>&1 & + pids+=($!) # Record the process ID +done + +# for gpu in 0 1 2 3 4 5 6 7; do +# for process in {0..1}; do +# section=$((gpu * 2 + process)) +# CUDA_VISIBLE_DEVICES=$gpu python render.py \ +# --output_dir outputs/nuscenes_train_plain \ +# --device 'cuda' \ +# --batch_size 16 \ +# --image_size 512 \ +# --total_sections 16 \ +# --section $section > ../logs/log_${gpu}_${section}.out 2>&1 & +# pids+=($!) # Record the process ID +# done +# done \ No newline at end of file diff --git a/monst3r/render_monst3r_node2.sh b/monst3r/render_monst3r_node2.sh new file mode 100644 index 0000000..eb8e5d9 --- /dev/null +++ b/monst3r/render_monst3r_node2.sh @@ -0,0 +1,33 @@ +CUDA_VISIBLE_DEVICES=0 python render.py \ + --output_dir outputs/nuscenes_train_plain_stride1window3 \ + --device 'cuda' \ + --batch_size 32 \ + --image_size 512 \ + --total_sections 1 \ + --section 0 + +# for gpu in 4 5 6 7; do +# section=$((gpu * 1 + 8)) +# CUDA_VISIBLE_DEVICES=$gpu python render.py \ +# --output_dir outputs/nuscenes_train_plain_stride1window3 \ +# --device 'cuda' \ +# --batch_size 32 \ +# --image_size 512 \ +# --total_sections 16 \ +# --section $section > ../logs/log_${gpu}_${section}.out 2>&1 & +# pids+=($!) # Record the process ID +# done + +# for gpu in 0 1 2 3 4 5 6 7; do +# for process in {0..1}; do +# section=$((gpu * 2 + process)) +# CUDA_VISIBLE_DEVICES=$gpu python render.py \ +# --output_dir outputs/nuscenes_train_plain \ +# --device 'cuda' \ +# --batch_size 16 \ +# --image_size 512 \ +# --total_sections 16 \ +# --section $section > ../logs/log_${gpu}_${section}.out 2>&1 & +# pids+=($!) # Record the process ID +# done +# done \ No newline at end of file diff --git a/monst3r/requirements.txt b/monst3r/requirements.txt new file mode 100644 index 0000000..827de1a --- /dev/null +++ b/monst3r/requirements.txt @@ -0,0 +1,22 @@ +# minimal requirements for inference +torch +torchvision +roma +gradio +matplotlib +tqdm +opencv-python +scipy +einops +gdown +trimesh +pyglet<2 +huggingface-hub[torch]>=0.22 + +# for camera trajectory +evo + +# for sam2, customized version +-e third_party/sam2 + +pyrender \ No newline at end of file diff --git a/monst3r/requirements_optional.txt b/monst3r/requirements_optional.txt new file mode 100644 index 0000000..d491494 --- /dev/null +++ b/monst3r/requirements_optional.txt @@ -0,0 +1,27 @@ +pillow-heif # add heif/heic image support +pyrender # for rendering depths in scannetpp +kapture # for visloc data loading +kapture-localization +numpy-quaternion +pycolmap # for pnp +poselib # for pnp + +# for download tartanair +boto3 +# for preprocess waymo +tensorflow +waymo-open-dataset-tf-2-12-0 --no-deps +# need to change bytearray() to bytes() in waymo_open_dataset package + +# for training +# for logging +wandb +tensorboard +#for pointodyssey +prettytable +scikit-image +scikit-learn +h5py +gdown +# for scannet +pypng \ No newline at end of file diff --git a/monst3r/third_party/RAFT/LICENSE b/monst3r/third_party/RAFT/LICENSE new file mode 100644 index 0000000..ed13d84 --- /dev/null +++ b/monst3r/third_party/RAFT/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2020, princeton-vl +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/monst3r/third_party/RAFT/README.md b/monst3r/third_party/RAFT/README.md new file mode 100644 index 0000000..650275e --- /dev/null +++ b/monst3r/third_party/RAFT/README.md @@ -0,0 +1,80 @@ +# RAFT +This repository contains the source code for our paper: + +[RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
+ECCV 2020
+Zachary Teed and Jia Deng
+ + + +## Requirements +The code has been tested with PyTorch 1.6 and Cuda 10.1. +```Shell +conda create --name raft +conda activate raft +conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch +``` + +## Demos +Pretrained models can be downloaded by running +```Shell +./download_models.sh +``` +or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing) + +You can demo a trained model on a sequence of frames +```Shell +python demo.py --model=models/raft-things.pth --path=demo-frames +``` + +## Required Data +To evaluate/train RAFT, you will need to download the required datasets. +* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) +* [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) +* [Sintel](http://sintel.is.tue.mpg.de/) +* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) +* [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional) + + +By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder + +```Shell +├── datasets + ├── Sintel + ├── test + ├── training + ├── KITTI + ├── testing + ├── training + ├── devkit + ├── FlyingChairs_release + ├── data + ├── FlyingThings3D + ├── frames_cleanpass + ├── frames_finalpass + ├── optical_flow +``` + +## Evaluation +You can evaluate a trained model using `evaluate.py` +```Shell +python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision +``` + +## Training +We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` which can be visualized using tensorboard +```Shell +./train_standard.sh +``` + +If you have a RTX GPU, training can be accelerated using mixed precision. You can expect similiar results in this setting (1 GPU) +```Shell +./train_mixed.sh +``` + +## (Optional) Efficent Implementation +You can optionally use our alternate (efficent) implementation by compiling the provided cuda extension +```Shell +cd alt_cuda_corr && python setup.py install && cd .. +``` +and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass. diff --git a/monst3r/third_party/RAFT/alt_cuda_corr/correlation.cpp b/monst3r/third_party/RAFT/alt_cuda_corr/correlation.cpp new file mode 100644 index 0000000..b01584d --- /dev/null +++ b/monst3r/third_party/RAFT/alt_cuda_corr/correlation.cpp @@ -0,0 +1,54 @@ +#include +#include + +// CUDA forward declarations +std::vector corr_cuda_forward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + int radius); + +std::vector corr_cuda_backward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + torch::Tensor corr_grad, + int radius); + +// C++ interface +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector corr_forward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + int radius) { + CHECK_INPUT(fmap1); + CHECK_INPUT(fmap2); + CHECK_INPUT(coords); + + return corr_cuda_forward(fmap1, fmap2, coords, radius); +} + + +std::vector corr_backward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + torch::Tensor corr_grad, + int radius) { + CHECK_INPUT(fmap1); + CHECK_INPUT(fmap2); + CHECK_INPUT(coords); + CHECK_INPUT(corr_grad); + + return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &corr_forward, "CORR forward"); + m.def("backward", &corr_backward, "CORR backward"); +} \ No newline at end of file diff --git a/monst3r/third_party/RAFT/alt_cuda_corr/correlation_kernel.cu b/monst3r/third_party/RAFT/alt_cuda_corr/correlation_kernel.cu new file mode 100644 index 0000000..145e580 --- /dev/null +++ b/monst3r/third_party/RAFT/alt_cuda_corr/correlation_kernel.cu @@ -0,0 +1,324 @@ +#include +#include +#include +#include + + +#define BLOCK_H 4 +#define BLOCK_W 8 +#define BLOCK_HW BLOCK_H * BLOCK_W +#define CHANNEL_STRIDE 32 + + +__forceinline__ __device__ +bool within_bounds(int h, int w, int H, int W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +template +__global__ void corr_forward_kernel( + const torch::PackedTensorAccessor32 fmap1, + const torch::PackedTensorAccessor32 fmap2, + const torch::PackedTensorAccessor32 coords, + torch::PackedTensorAccessor32 corr, + int r) +{ + const int b = blockIdx.x; + const int h0 = blockIdx.y * blockDim.x; + const int w0 = blockIdx.z * blockDim.y; + const int tid = threadIdx.x * blockDim.y + threadIdx.y; + + const int H1 = fmap1.size(1); + const int W1 = fmap1.size(2); + const int H2 = fmap2.size(1); + const int W2 = fmap2.size(2); + const int N = coords.size(1); + const int C = fmap1.size(3); + + __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; + __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; + __shared__ scalar_t x2s[BLOCK_HW]; + __shared__ scalar_t y2s[BLOCK_HW]; + + for (int c=0; c(floor(y2s[k1]))-r+iy; + int w2 = static_cast(floor(x2s[k1]))-r+ix; + int c2 = tid % CHANNEL_STRIDE; + + auto fptr = fmap2[b][h2][w2]; + if (within_bounds(h2, w2, H2, W2)) + f2[c2][k1] = fptr[c+c2]; + else + f2[c2][k1] = 0.0; + } + + __syncthreads(); + + scalar_t s = 0.0; + for (int k=0; k 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) + *(corr_ptr + ix_nw) += nw; + + if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) + *(corr_ptr + ix_ne) += ne; + + if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) + *(corr_ptr + ix_sw) += sw; + + if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) + *(corr_ptr + ix_se) += se; + } + } + } + } +} + + +template +__global__ void corr_backward_kernel( + const torch::PackedTensorAccessor32 fmap1, + const torch::PackedTensorAccessor32 fmap2, + const torch::PackedTensorAccessor32 coords, + const torch::PackedTensorAccessor32 corr_grad, + torch::PackedTensorAccessor32 fmap1_grad, + torch::PackedTensorAccessor32 fmap2_grad, + torch::PackedTensorAccessor32 coords_grad, + int r) +{ + + const int b = blockIdx.x; + const int h0 = blockIdx.y * blockDim.x; + const int w0 = blockIdx.z * blockDim.y; + const int tid = threadIdx.x * blockDim.y + threadIdx.y; + + const int H1 = fmap1.size(1); + const int W1 = fmap1.size(2); + const int H2 = fmap2.size(1); + const int W2 = fmap2.size(2); + const int N = coords.size(1); + const int C = fmap1.size(3); + + __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; + __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; + + __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1]; + __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1]; + + __shared__ scalar_t x2s[BLOCK_HW]; + __shared__ scalar_t y2s[BLOCK_HW]; + + for (int c=0; c(floor(y2s[k1]))-r+iy; + int w2 = static_cast(floor(x2s[k1]))-r+ix; + int c2 = tid % CHANNEL_STRIDE; + + auto fptr = fmap2[b][h2][w2]; + if (within_bounds(h2, w2, H2, W2)) + f2[c2][k1] = fptr[c+c2]; + else + f2[c2][k1] = 0.0; + + f2_grad[c2][k1] = 0.0; + } + + __syncthreads(); + + const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1]; + scalar_t g = 0.0; + + int ix_nw = H1*W1*((iy-1) + rd*(ix-1)); + int ix_ne = H1*W1*((iy-1) + rd*ix); + int ix_sw = H1*W1*(iy + rd*(ix-1)); + int ix_se = H1*W1*(iy + rd*ix); + + if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) + g += *(grad_ptr + ix_nw) * dy * dx; + + if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) + g += *(grad_ptr + ix_ne) * dy * (1-dx); + + if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) + g += *(grad_ptr + ix_sw) * (1-dy) * dx; + + if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) + g += *(grad_ptr + ix_se) * (1-dy) * (1-dx); + + for (int k=0; k(floor(y2s[k1]))-r+iy; + int w2 = static_cast(floor(x2s[k1]))-r+ix; + int c2 = tid % CHANNEL_STRIDE; + + scalar_t* fptr = &fmap2_grad[b][h2][w2][0]; + if (within_bounds(h2, w2, H2, W2)) + atomicAdd(fptr+c+c2, f2_grad[c2][k1]); + } + } + } + } + __syncthreads(); + + + for (int k=0; k corr_cuda_forward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + int radius) +{ + const auto B = coords.size(0); + const auto N = coords.size(1); + const auto H = coords.size(2); + const auto W = coords.size(3); + + const auto rd = 2 * radius + 1; + auto opts = fmap1.options(); + auto corr = torch::zeros({B, N, rd*rd, H, W}, opts); + + const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W); + const dim3 threads(BLOCK_H, BLOCK_W); + + corr_forward_kernel<<>>( + fmap1.packed_accessor32(), + fmap2.packed_accessor32(), + coords.packed_accessor32(), + corr.packed_accessor32(), + radius); + + return {corr}; +} + +std::vector corr_cuda_backward( + torch::Tensor fmap1, + torch::Tensor fmap2, + torch::Tensor coords, + torch::Tensor corr_grad, + int radius) +{ + const auto B = coords.size(0); + const auto N = coords.size(1); + + const auto H1 = fmap1.size(1); + const auto W1 = fmap1.size(2); + const auto H2 = fmap2.size(1); + const auto W2 = fmap2.size(2); + const auto C = fmap1.size(3); + + auto opts = fmap1.options(); + auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts); + auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts); + auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts); + + const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W); + const dim3 threads(BLOCK_H, BLOCK_W); + + + corr_backward_kernel<<>>( + fmap1.packed_accessor32(), + fmap2.packed_accessor32(), + coords.packed_accessor32(), + corr_grad.packed_accessor32(), + fmap1_grad.packed_accessor32(), + fmap2_grad.packed_accessor32(), + coords_grad.packed_accessor32(), + radius); + + return {fmap1_grad, fmap2_grad, coords_grad}; +} \ No newline at end of file diff --git a/monst3r/third_party/RAFT/alt_cuda_corr/setup.py b/monst3r/third_party/RAFT/alt_cuda_corr/setup.py new file mode 100644 index 0000000..c0207ff --- /dev/null +++ b/monst3r/third_party/RAFT/alt_cuda_corr/setup.py @@ -0,0 +1,15 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +setup( + name='correlation', + ext_modules=[ + CUDAExtension('alt_cuda_corr', + sources=['correlation.cpp', 'correlation_kernel.cu'], + extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), + ], + cmdclass={ + 'build_ext': BuildExtension + }) + diff --git a/monst3r/third_party/RAFT/chairs_split.txt b/monst3r/third_party/RAFT/chairs_split.txt new file mode 100644 index 0000000..6ae8f0b --- /dev/null +++ b/monst3r/third_party/RAFT/chairs_split.txt @@ -0,0 +1,22872 @@ +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +2 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 \ No newline at end of file diff --git a/monst3r/third_party/RAFT/core/__init__.py b/monst3r/third_party/RAFT/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/monst3r/third_party/RAFT/core/configs/congif_spring_M.json b/monst3r/third_party/RAFT/core/configs/congif_spring_M.json new file mode 100644 index 0000000..7e9c8a4 --- /dev/null +++ b/monst3r/third_party/RAFT/core/configs/congif_spring_M.json @@ -0,0 +1,30 @@ +{ + "name": "spring-M", + "dataset": "spring", + "gpus": [0, 1, 2, 3, 4, 5, 6, 7], + + "use_var": true, + "var_min": 0, + "var_max": 10, + "pretrain": "resnet34", + "initial_dim": 64, + "block_dims": [64, 128, 256], + "radius": 4, + "dim": 128, + "num_blocks": 2, + "iters": 4, + + "image_size": [540, 960], + "scale": -1, + "batch_size": 32, + "epsilon": 1e-8, + "lr": 4e-4, + "wdecay": 1e-5, + "dropout": 0, + "clip": 1.0, + "gamma": 0.85, + "num_steps": 120000, + + "restore_ckpt": null, + "coarse_config": null +} \ No newline at end of file diff --git a/monst3r/third_party/RAFT/core/corr.py b/monst3r/third_party/RAFT/core/corr.py new file mode 100644 index 0000000..c977add --- /dev/null +++ b/monst3r/third_party/RAFT/core/corr.py @@ -0,0 +1,142 @@ +import torch +import torch.nn.functional as F +from utils.utils import bilinear_sampler, coords_grid + +try: + import alt_cuda_corr +except: + # alt_cuda_corr is not compiled + pass + +class CorrBlock2: + def __init__(self, fmap1, fmap2, args): + self.num_levels = args.corr_levels + self.radius = args.corr_radius + self.args = args + self.corr_pyramid = [] + # all pairs correlation + for i in range(self.num_levels): + corr = CorrBlock2.corr(fmap1, fmap2, 1) + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + fmap2 = F.interpolate(fmap2, scale_factor=0.5, mode='bilinear', align_corners=False) + self.corr_pyramid.append(corr) + + def __call__(self, coords, dilation=None): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + if dilation is None: + dilation = torch.ones(batch, 1, h1, w1, device=coords.device) + + # print(dilation.max(), dilation.mean(), dilation.min()) + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + device = coords.device + dx = torch.linspace(-r, r, 2*r+1, device=device) + dy = torch.linspace(-r, r, 2*r+1, device=device) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + delta_lvl = delta_lvl * dilation.view(batch * h1 * w1, 1, 1, 1) + centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i + coords_lvl = centroid_lvl + delta_lvl + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + out = out.permute(0, 3, 1, 2).contiguous().float() + return out + + @staticmethod + def corr(fmap1, fmap2, num_head): + batch, dim, h1, w1 = fmap1.shape + h2, w2 = fmap2.shape[2:] + fmap1 = fmap1.view(batch, num_head, dim // num_head, h1*w1) + fmap2 = fmap2.view(batch, num_head, dim // num_head, h2*w2) + corr = fmap1.transpose(2, 3) @ fmap2 + corr = corr.reshape(batch, num_head, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5) + return corr / torch.sqrt(torch.tensor(dim).float()) + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/monst3r/third_party/RAFT/core/datasets.py b/monst3r/third_party/RAFT/core/datasets.py new file mode 100644 index 0000000..3411fda --- /dev/null +++ b/monst3r/third_party/RAFT/core/datasets.py @@ -0,0 +1,235 @@ +# Data loading based on https://github.com/NVIDIA/flownet2-pytorch + +import numpy as np +import torch +import torch.utils.data as data +import torch.nn.functional as F + +import os +import math +import random +from glob import glob +import os.path as osp + +from utils import frame_utils +from utils.augmentor import FlowAugmentor, SparseFlowAugmentor + + +class FlowDataset(data.Dataset): + def __init__(self, aug_params=None, sparse=False): + self.augmentor = None + self.sparse = sparse + if aug_params is not None: + if sparse: + self.augmentor = SparseFlowAugmentor(**aug_params) + else: + self.augmentor = FlowAugmentor(**aug_params) + + self.is_test = False + self.init_seed = False + self.flow_list = [] + self.image_list = [] + self.extra_info = [] + + def __getitem__(self, index): + + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + if self.sparse: + flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) + else: + flow = frame_utils.read_gen(self.flow_list[index]) + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[...,None], (1, 1, 3)) + img2 = np.tile(img2[...,None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + return img1, img2, flow, valid.float() + + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return len(self.image_list) + + +class MpiSintel(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): + super(MpiSintel, self).__init__(aug_params) + flow_root = osp.join(root, split, 'flow') + image_root = osp.join(root, split, dstype) + + if split == 'test': + self.is_test = True + + for scene in os.listdir(image_root): + image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) + for i in range(len(image_list)-1): + self.image_list += [ [image_list[i], image_list[i+1]] ] + self.extra_info += [ (scene, i) ] # scene and frame_id + + if split != 'test': + self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) + + +class FlyingChairs(FlowDataset): + def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): + super(FlyingChairs, self).__init__(aug_params) + + images = sorted(glob(osp.join(root, '*.ppm'))) + flows = sorted(glob(osp.join(root, '*.flo'))) + assert (len(images)//2 == len(flows)) + + split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) + for i in range(len(flows)): + xid = split_list[i] + if (split=='training' and xid==1) or (split=='validation' and xid==2): + self.flow_list += [ flows[i] ] + self.image_list += [ [images[2*i], images[2*i+1]] ] + + +class FlyingThings3D(FlowDataset): + def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): + super(FlyingThings3D, self).__init__(aug_params) + + for cam in ['left']: + for direction in ['into_future', 'into_past']: + image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) + image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) + + flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) + flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) + + for idir, fdir in zip(image_dirs, flow_dirs): + images = sorted(glob(osp.join(idir, '*.png')) ) + flows = sorted(glob(osp.join(fdir, '*.pfm')) ) + for i in range(len(flows)-1): + if direction == 'into_future': + self.image_list += [ [images[i], images[i+1]] ] + self.flow_list += [ flows[i] ] + elif direction == 'into_past': + self.image_list += [ [images[i+1], images[i]] ] + self.flow_list += [ flows[i+1] ] + + +class KITTI(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): + super(KITTI, self).__init__(aug_params, sparse=True) + if split == 'testing': + self.is_test = True + + root = osp.join(root, split) + images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) + images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) + + for img1, img2 in zip(images1, images2): + frame_id = img1.split('/')[-1] + self.extra_info += [ [frame_id] ] + self.image_list += [ [img1, img2] ] + + if split == 'training': + self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) + + +class HD1K(FlowDataset): + def __init__(self, aug_params=None, root='datasets/HD1k'): + super(HD1K, self).__init__(aug_params, sparse=True) + + seq_ix = 0 + while 1: + flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) + images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) + + if len(flows) == 0: + break + + for i in range(len(flows)-1): + self.flow_list += [flows[i]] + self.image_list += [ [images[i], images[i+1]] ] + + seq_ix += 1 + + +def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): + """ Create the data loader for the corresponding trainign set """ + + if args.stage == 'chairs': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} + train_dataset = FlyingChairs(aug_params, split='training') + + elif args.stage == 'things': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} + clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') + final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') + train_dataset = clean_dataset + final_dataset + + elif args.stage == 'sintel': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} + things = FlyingThings3D(aug_params, dstype='frames_cleanpass') + sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') + sintel_final = MpiSintel(aug_params, split='training', dstype='final') + + if TRAIN_DS == 'C+T+K+S+H': + kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) + hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) + train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things + + elif TRAIN_DS == 'C+T+K/S': + train_dataset = 100*sintel_clean + 100*sintel_final + things + + elif args.stage == 'kitti': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} + train_dataset = KITTI(aug_params, split='training') + + train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, + pin_memory=False, shuffle=True, num_workers=4, drop_last=True) + + print('Training with %d image pairs' % len(train_dataset)) + return train_loader + diff --git a/monst3r/third_party/RAFT/core/extractor.py b/monst3r/third_party/RAFT/core/extractor.py new file mode 100644 index 0000000..9c1799c --- /dev/null +++ b/monst3r/third_party/RAFT/core/extractor.py @@ -0,0 +1,351 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from layer import conv1x1, conv3x3, BasicBlock + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class ResNetFPN(nn.Module): + """ + ResNet18, output resolution is 1/8. + Each block has 2 layers. + """ + def __init__(self, args, input_dim=3, output_dim=256, ratio=1.0, norm_layer=nn.BatchNorm2d, init_weight=False): + super().__init__() + # Config + block = BasicBlock + block_dims = args.block_dims + initial_dim = args.initial_dim + self.init_weight = init_weight + self.input_dim = input_dim + # Class Variable + self.in_planes = initial_dim + for i in range(len(block_dims)): + block_dims[i] = int(block_dims[i] * ratio) + # Networks + self.conv1 = nn.Conv2d(input_dim, initial_dim, kernel_size=7, stride=2, padding=3) + self.bn1 = norm_layer(initial_dim) + self.relu = nn.ReLU(inplace=True) + if args.pretrain == 'resnet34': + n_block = [3, 4, 6] + elif args.pretrain == 'resnet18': + n_block = [2, 2, 2] + else: + raise NotImplementedError + self.layer1 = self._make_layer(block, block_dims[0], stride=1, norm_layer=norm_layer, num=n_block[0]) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2, norm_layer=norm_layer, num=n_block[1]) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2, norm_layer=norm_layer, num=n_block[2]) # 1/8 + self.final_conv = conv1x1(block_dims[2], output_dim) + self._init_weights(args) + + def _init_weights(self, args): + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + if self.init_weight: + from torchvision.models import resnet18, ResNet18_Weights, resnet34, ResNet34_Weights + if args.pretrain == 'resnet18': + pretrained_dict = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).state_dict() + else: + pretrained_dict = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1).state_dict() + model_dict = self.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + if self.input_dim == 6: + for k, v in pretrained_dict.items(): + if k == 'conv1.weight': + pretrained_dict[k] = torch.cat((v, v), dim=1) + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict, strict=False) + + + def _make_layer(self, block, dim, stride=1, norm_layer=nn.BatchNorm2d, num=2): + layers = [] + layers.append(block(self.in_planes, dim, stride=stride, norm_layer=norm_layer)) + for i in range(num - 1): + layers.append(block(dim, dim, stride=1, norm_layer=norm_layer)) + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x = self.relu(self.bn1(self.conv1(x))) + for i in range(len(self.layer1)): + x = self.layer1[i](x) + for i in range(len(self.layer2)): + x = self.layer2[i](x) + for i in range(len(self.layer3)): + x = self.layer3[i](x) + # Output + output = self.final_conv(x) + return output \ No newline at end of file diff --git a/monst3r/third_party/RAFT/core/layer.py b/monst3r/third_party/RAFT/core/layer.py new file mode 100644 index 0000000..ecf1e64 --- /dev/null +++ b/monst3r/third_party/RAFT/core/layer.py @@ -0,0 +1,135 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torch +import math +from torch.nn import Module, Dropout + +### Gradient Clipping and Zeroing Operations ### + +GRAD_CLIP = 0.1 + +class GradClip(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x + + @staticmethod + def backward(ctx, grad_x): + grad_x = torch.where(torch.isnan(grad_x), torch.zeros_like(grad_x), grad_x) + return grad_x.clamp(min=-0.01, max=0.01) + +class GradientClip(nn.Module): + def __init__(self): + super(GradientClip, self).__init__() + + def forward(self, x): + return GradClip.apply(x) + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + +class ConvNextBlock(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + def __init__(self, dim, output_dim, layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * output_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * output_dim, dim) + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.final = nn.Conv2d(dim, output_dim, kernel_size=1, padding=0) + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + x = self.final(input + x) + return x + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution without padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1) + +class BasicBlock(nn.Module): + def __init__(self, in_planes, planes, stride=1, norm_layer=nn.BatchNorm2d): + super().__init__() + + # self.sparse = sparse + self.conv1 = conv3x3(in_planes, planes, stride) + self.conv2 = conv3x3(planes, planes) + self.bn1 = norm_layer(planes) + self.bn2 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + if stride == 1 and in_planes == planes: + self.downsample = None + else: + self.bn3 = norm_layer(planes) + self.downsample = nn.Sequential( + conv1x1(in_planes, planes, stride=stride), + self.bn3 + ) + + def forward(self, x): + y = x + y = self.relu(self.bn1(self.conv1(y))) + y = self.relu(self.bn2(self.conv2(y))) + if self.downsample is not None: + x = self.downsample(x) + return self.relu(x+y) \ No newline at end of file diff --git a/monst3r/third_party/RAFT/core/raft.py b/monst3r/third_party/RAFT/core/raft.py new file mode 100644 index 0000000..ae403c6 --- /dev/null +++ b/monst3r/third_party/RAFT/core/raft.py @@ -0,0 +1,291 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from update import BasicUpdateBlock, SmallUpdateBlock, BasicUpdateBlock2 +from extractor import BasicEncoder, SmallEncoder, ResNetFPN +from corr import CorrBlock, AlternateCorrBlock, CorrBlock2 +from utils.utils import bilinear_sampler, coords_grid, upflow8, InputPadder, coords_grid2 +from layer import conv3x3 +import math + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + + +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + + if 'dropout' not in self.args: + self.args.dropout = 0 + + if 'alternate_corr' not in self.args: + self.args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) + self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) + self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H//8, W//8).to(img.device) + coords1 = coords_grid(N, H//8, W//8).to(img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8*H, 8*W) + + + def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): + """ Estimate optical flow between pair of frames """ + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args.mixed_precision): + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return coords1 - coords0, flow_up + + return flow_predictions +## +# given depth, warp according to camera params. + +# given flow+depth, warp in 2D + +class RAFT2(nn.Module): + def __init__(self, args): + super(RAFT2, self).__init__() + self.args = args + self.output_dim = args.dim * 2 + + self.args.corr_levels = 4 + self.args.corr_radius = args.radius + self.args.corr_channel = args.corr_levels * (args.radius * 2 + 1) ** 2 + self.cnet = ResNetFPN(args, input_dim=6, output_dim=2 * self.args.dim, norm_layer=nn.BatchNorm2d, init_weight=True) + + # conv for iter 0 results + self.init_conv = conv3x3(2 * args.dim, 2 * args.dim) + self.upsample_weight = nn.Sequential( + # convex combination of 3x3 patches + nn.Conv2d(args.dim, args.dim * 2, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(args.dim * 2, 64 * 9, 1, padding=0) + ) + self.flow_head = nn.Sequential( + # flow(2) + weight(2) + log_b(2) + nn.Conv2d(args.dim, 2 * args.dim, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(2 * args.dim, 6, 3, padding=1) + ) + if args.iters > 0: + self.fnet = ResNetFPN(args, input_dim=3, output_dim=self.output_dim, norm_layer=nn.BatchNorm2d, init_weight=True) + self.update_block = BasicUpdateBlock2(args, hdim=args.dim, cdim=args.dim) + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords2 - coords1""" + N, C, H, W = img.shape + coords1 = coords_grid(N, H//8, W//8, device=img.device) + coords2 = coords_grid(N, H//8, W//8, device=img.device) + return coords1, coords2 + + def upsample_data(self, flow, info, mask): + """ Upsample [H/8, W/8, C] -> [H, W, C] using convex combination """ + N, C, H, W = info.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + up_info = F.unfold(info, [3, 3], padding=1) + up_info = up_info.view(N, C, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + up_info = torch.sum(mask * up_info, dim=2) + up_info = up_info.permute(0, 1, 4, 2, 5, 3) + + return up_flow.reshape(N, 2, 8*H, 8*W), up_info.reshape(N, C, 8*H, 8*W) + + def forward(self, image1, image2, iters=None, flow_gt=None, test_mode=False): + """ Estimate optical flow between pair of frames """ + N, _, H, W = image1.shape + if iters is None: + iters = self.args.iters + if flow_gt is None: + flow_gt = torch.zeros(N, 2, H, W, device=image1.device) + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + image1 = image1.contiguous() + image2 = image2.contiguous() + flow_predictions = [] + info_predictions = [] + + # padding + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + N, _, H, W = image1.shape + dilation = torch.ones(N, 1, H//8, W//8, device=image1.device) + # run the context network + cnet = self.cnet(torch.cat([image1, image2], dim=1)) + cnet = self.init_conv(cnet) + net, context = torch.split(cnet, [self.args.dim, self.args.dim], dim=1) + + # init flow + flow_update = self.flow_head(net) + weight_update = .25 * self.upsample_weight(net) + flow_8x = flow_update[:, :2] + info_8x = flow_update[:, 2:] + flow_up, info_up = self.upsample_data(flow_8x, info_8x, weight_update) + flow_predictions.append(flow_up) + info_predictions.append(info_up) + + if self.args.iters > 0: + # run the feature network + fmap1_8x = self.fnet(image1) + fmap2_8x = self.fnet(image2) + corr_fn = CorrBlock2(fmap1_8x, fmap2_8x, self.args) + + for itr in range(iters): + N, _, H, W = flow_8x.shape + flow_8x = flow_8x.detach() + coords2 = (coords_grid2(N, H, W, device=image1.device) + flow_8x).detach() + corr = corr_fn(coords2, dilation=dilation) + net = self.update_block(net, context, corr, flow_8x) + flow_update = self.flow_head(net) + weight_update = .25 * self.upsample_weight(net) + flow_8x = flow_8x + flow_update[:, :2] + info_8x = flow_update[:, 2:] + # upsample predictions + flow_up, info_up = self.upsample_data(flow_8x, info_8x, weight_update) + flow_predictions.append(flow_up) + info_predictions.append(info_up) + + for i in range(len(info_predictions)): + flow_predictions[i] = padder.unpad(flow_predictions[i]) + info_predictions[i] = padder.unpad(info_predictions[i]) + + if test_mode == False: + # exlude invalid pixels and extremely large diplacements + nf_predictions = [] + for i in range(len(info_predictions)): + if not self.args.use_var: + var_max = var_min = 0 + else: + var_max = self.args.var_max + var_min = self.args.var_min + + raw_b = info_predictions[i][:, 2:] + log_b = torch.zeros_like(raw_b) + weight = info_predictions[i][:, :2] + # Large b Component + log_b[:, 0] = torch.clamp(raw_b[:, 0], min=0, max=var_max) + # Small b Component + log_b[:, 1] = torch.clamp(raw_b[:, 1], min=var_min, max=0) + # term2: [N, 2, m, H, W] + term2 = ((flow_gt - flow_predictions[i]).abs().unsqueeze(2)) * (torch.exp(-log_b).unsqueeze(1)) + # term1: [N, m, H, W] + term1 = weight - math.log(2) - log_b + nf_loss = torch.logsumexp(weight, dim=1, keepdim=True) - torch.logsumexp(term1.unsqueeze(1) - term2, dim=2) + nf_predictions.append(nf_loss) + + return {'final': flow_predictions[-1], 'flow': flow_predictions, 'info': info_predictions, 'nf': nf_predictions} + else: + return [flow_predictions,flow_predictions[-1]] \ No newline at end of file diff --git a/monst3r/third_party/RAFT/core/update.py b/monst3r/third_party/RAFT/core/update.py new file mode 100644 index 0000000..d9a023d --- /dev/null +++ b/monst3r/third_party/RAFT/core/update.py @@ -0,0 +1,174 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from layer import ConvNextBlock + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicMotionEncoder2(nn.Module): + def __init__(self, args, dim=128): + super(BasicMotionEncoder2, self).__init__() + cor_planes = args.corr_channel + self.convc1 = nn.Conv2d(cor_planes, dim*2, 1, padding=0) + self.convc2 = nn.Conv2d(dim*2, dim+dim//2, 3, padding=1) + self.convf1 = nn.Conv2d(2, dim, 7, padding=3) + self.convf2 = nn.Conv2d(dim, dim//2, 3, padding=1) + self.conv = nn.Conv2d(dim*2, dim-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + +class BasicUpdateBlock2(nn.Module): + def __init__(self, args, hdim=128, cdim=128): + #net: hdim, inp: cdim + super(BasicUpdateBlock2, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder2(args, dim=cdim) + self.refine = [] + for i in range(args.num_blocks): + self.refine.append(ConvNextBlock(2*cdim+hdim, hdim)) + self.refine = nn.ModuleList(self.refine) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + for blk in self.refine: + net = blk(torch.cat([net, inp], dim=1)) + return net diff --git a/monst3r/third_party/RAFT/core/utils/__init__.py b/monst3r/third_party/RAFT/core/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/monst3r/third_party/RAFT/core/utils/augmentor.py b/monst3r/third_party/RAFT/core/utils/augmentor.py new file mode 100644 index 0000000..e81c4f2 --- /dev/null +++ b/monst3r/third_party/RAFT/core/utils/augmentor.py @@ -0,0 +1,246 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torch.nn.functional as F + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + return img1, img2, flow + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/monst3r/third_party/RAFT/core/utils/flow_viz.py b/monst3r/third_party/RAFT/core/utils/flow_viz.py new file mode 100644 index 0000000..dcee65e --- /dev/null +++ b/monst3r/third_party/RAFT/core/utils/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/monst3r/third_party/RAFT/core/utils/frame_utils.py b/monst3r/third_party/RAFT/core/utils/frame_utils.py new file mode 100644 index 0000000..6c49113 --- /dev/null +++ b/monst3r/third_party/RAFT/core/utils/frame_utils.py @@ -0,0 +1,137 @@ +import numpy as np +from PIL import Image +from os.path import * +import re + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +TAG_CHAR = np.array([202021.25], np.float32) + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def writeFlow(filename,uv,v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) + flow = flow[:,:,::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] \ No newline at end of file diff --git a/monst3r/third_party/RAFT/core/utils/utils.py b/monst3r/third_party/RAFT/core/utils/utils.py new file mode 100644 index 0000000..29b75f9 --- /dev/null +++ b/monst3r/third_party/RAFT/core/utils/utils.py @@ -0,0 +1,86 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + +def coords_grid2(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/monst3r/third_party/RAFT/demo.py b/monst3r/third_party/RAFT/demo.py new file mode 100644 index 0000000..5abc1da --- /dev/null +++ b/monst3r/third_party/RAFT/demo.py @@ -0,0 +1,75 @@ +import sys +sys.path.append('core') + +import argparse +import os +import cv2 +import glob +import numpy as np +import torch +from PIL import Image + +from raft import RAFT +from utils import flow_viz +from utils.utils import InputPadder + + + +DEVICE = 'cuda' + +def load_image(imfile): + img = np.array(Image.open(imfile)).astype(np.uint8) + img = torch.from_numpy(img).permute(2, 0, 1).float() + return img[None].to(DEVICE) + + +def viz(img, flo): + img = img[0].permute(1,2,0).cpu().numpy() + flo = flo[0].permute(1,2,0).cpu().numpy() + + # map flow to rgb image + flo = flow_viz.flow_to_image(flo) + img_flo = np.concatenate([img, flo], axis=0) + + # import matplotlib.pyplot as plt + # plt.imshow(img_flo / 255.0) + # plt.show() + + cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) + cv2.waitKey() + + +def demo(args): + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.model)) + + model = model.module + model.to(DEVICE) + model.eval() + + with torch.no_grad(): + images = glob.glob(os.path.join(args.path, '*.png')) + \ + glob.glob(os.path.join(args.path, '*.jpg')) + + images = sorted(images) + for imfile1, imfile2 in zip(images[:-1], images[1:]): + image1 = load_image(imfile1) + image2 = load_image(imfile2) + + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + + flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) + viz(image1, flow_up) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model', help="restore checkpoint") + parser.add_argument('--path', help="dataset for evaluation") + parser.add_argument('--small', action='store_true', help='use small model') + parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') + parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') + args = parser.parse_args() + + demo(args) diff --git a/monst3r/third_party/RAFT/download_models.sh b/monst3r/third_party/RAFT/download_models.sh new file mode 100755 index 0000000..7b6ed7e --- /dev/null +++ b/monst3r/third_party/RAFT/download_models.sh @@ -0,0 +1,3 @@ +#!/bin/bash +wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip +unzip models.zip diff --git a/monst3r/third_party/RAFT/evaluate.py b/monst3r/third_party/RAFT/evaluate.py new file mode 100644 index 0000000..431a0f5 --- /dev/null +++ b/monst3r/third_party/RAFT/evaluate.py @@ -0,0 +1,197 @@ +import sys +sys.path.append('core') + +from PIL import Image +import argparse +import os +import time +import numpy as np +import torch +import torch.nn.functional as F +import matplotlib.pyplot as plt + +import datasets +from utils import flow_viz +from utils import frame_utils + +from raft import RAFT +from utils.utils import InputPadder, forward_interpolate + + +@torch.no_grad() +def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'): + """ Create submission for the Sintel leaderboard """ + model.eval() + for dstype in ['clean', 'final']: + test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype) + + flow_prev, sequence_prev = None, None + for test_id in range(len(test_dataset)): + image1, image2, (sequence, frame) = test_dataset[test_id] + if sequence != sequence_prev: + flow_prev = None + + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + + flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True) + flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() + + if warm_start: + flow_prev = forward_interpolate(flow_low[0])[None].cuda() + + output_dir = os.path.join(output_path, dstype, sequence) + output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1)) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + frame_utils.writeFlow(output_file, flow) + sequence_prev = sequence + + +@torch.no_grad() +def create_kitti_submission(model, iters=24, output_path='kitti_submission'): + """ Create submission for the Sintel leaderboard """ + model.eval() + test_dataset = datasets.KITTI(split='testing', aug_params=None) + + if not os.path.exists(output_path): + os.makedirs(output_path) + + for test_id in range(len(test_dataset)): + image1, image2, (frame_id, ) = test_dataset[test_id] + padder = InputPadder(image1.shape, mode='kitti') + image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + + _, flow_pr = model(image1, image2, iters=iters, test_mode=True) + flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() + + output_filename = os.path.join(output_path, frame_id) + frame_utils.writeFlowKITTI(output_filename, flow) + + +@torch.no_grad() +def validate_chairs(model, iters=24): + """ Perform evaluation on the FlyingChairs (test) split """ + model.eval() + epe_list = [] + + val_dataset = datasets.FlyingChairs(split='validation') + for val_id in range(len(val_dataset)): + image1, image2, flow_gt, _ = val_dataset[val_id] + image1 = image1[None].cuda() + image2 = image2[None].cuda() + + _, flow_pr = model(image1, image2, iters=iters, test_mode=True) + epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt() + epe_list.append(epe.view(-1).numpy()) + + epe = np.mean(np.concatenate(epe_list)) + print("Validation Chairs EPE: %f" % epe) + return {'chairs': epe} + + +@torch.no_grad() +def validate_sintel(model, iters=32): + """ Peform validation using the Sintel (train) split """ + model.eval() + results = {} + for dstype in ['clean', 'final']: + val_dataset = datasets.MpiSintel(split='training', dstype=dstype) + epe_list = [] + + for val_id in range(len(val_dataset)): + image1, image2, flow_gt, _ = val_dataset[val_id] + image1 = image1[None].cuda() + image2 = image2[None].cuda() + + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + + flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) + flow = padder.unpad(flow_pr[0]).cpu() + + epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() + epe_list.append(epe.view(-1).numpy()) + + epe_all = np.concatenate(epe_list) + epe = np.mean(epe_all) + px1 = np.mean(epe_all<1) + px3 = np.mean(epe_all<3) + px5 = np.mean(epe_all<5) + + print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5)) + results[dstype] = np.mean(epe_list) + + return results + + +@torch.no_grad() +def validate_kitti(model, iters=24): + """ Peform validation using the KITTI-2015 (train) split """ + model.eval() + val_dataset = datasets.KITTI(split='training') + + out_list, epe_list = [], [] + for val_id in range(len(val_dataset)): + image1, image2, flow_gt, valid_gt = val_dataset[val_id] + image1 = image1[None].cuda() + image2 = image2[None].cuda() + + padder = InputPadder(image1.shape, mode='kitti') + image1, image2 = padder.pad(image1, image2) + + flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) + flow = padder.unpad(flow_pr[0]).cpu() + + epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() + mag = torch.sum(flow_gt**2, dim=0).sqrt() + + epe = epe.view(-1) + mag = mag.view(-1) + val = valid_gt.view(-1) >= 0.5 + + out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() + epe_list.append(epe[val].mean().item()) + out_list.append(out[val].cpu().numpy()) + + epe_list = np.array(epe_list) + out_list = np.concatenate(out_list) + + epe = np.mean(epe_list) + f1 = 100 * np.mean(out_list) + + print("Validation KITTI: %f, %f" % (epe, f1)) + return {'kitti-epe': epe, 'kitti-f1': f1} + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model', help="restore checkpoint") + parser.add_argument('--dataset', help="dataset for evaluation") + parser.add_argument('--small', action='store_true', help='use small model') + parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') + parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') + args = parser.parse_args() + + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.model)) + + model.cuda() + model.eval() + + # create_sintel_submission(model.module, warm_start=True) + # create_kitti_submission(model.module) + + with torch.no_grad(): + if args.dataset == 'chairs': + validate_chairs(model.module) + + elif args.dataset == 'sintel': + validate_sintel(model.module) + + elif args.dataset == 'kitti': + validate_kitti(model.module) + + diff --git a/monst3r/third_party/RAFT/train.py b/monst3r/third_party/RAFT/train.py new file mode 100644 index 0000000..3075730 --- /dev/null +++ b/monst3r/third_party/RAFT/train.py @@ -0,0 +1,247 @@ +from __future__ import print_function, division +import sys +sys.path.append('core') + +import argparse +import os +import cv2 +import time +import numpy as np +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F + +from torch.utils.data import DataLoader +from raft import RAFT +import evaluate +import datasets + +from torch.utils.tensorboard import SummaryWriter + +try: + from torch.cuda.amp import GradScaler +except: + # dummy GradScaler for PyTorch < 1.6 + class GradScaler: + def __init__(self): + pass + def scale(self, loss): + return loss + def unscale_(self, optimizer): + pass + def step(self, optimizer): + optimizer.step() + def update(self): + pass + + +# exclude extremly large displacements +MAX_FLOW = 400 +SUM_FREQ = 100 +VAL_FREQ = 5000 + + +def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW): + """ Loss function defined over sequence of flow predictions """ + + n_predictions = len(flow_preds) + flow_loss = 0.0 + + # exlude invalid pixels and extremely large diplacements + mag = torch.sum(flow_gt**2, dim=1).sqrt() + valid = (valid >= 0.5) & (mag < max_flow) + + for i in range(n_predictions): + i_weight = gamma**(n_predictions - i - 1) + i_loss = (flow_preds[i] - flow_gt).abs() + flow_loss += i_weight * (valid[:, None] * i_loss).mean() + + epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() + epe = epe.view(-1)[valid.view(-1)] + + metrics = { + 'epe': epe.mean().item(), + '1px': (epe < 1).float().mean().item(), + '3px': (epe < 3).float().mean().item(), + '5px': (epe < 5).float().mean().item(), + } + + return flow_loss, metrics + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def fetch_optimizer(args, model): + """ Create the optimizer and learning rate scheduler """ + optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) + + scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, + pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') + + return optimizer, scheduler + + +class Logger: + def __init__(self, model, scheduler): + self.model = model + self.scheduler = scheduler + self.total_steps = 0 + self.running_loss = {} + self.writer = None + + def _print_training_status(self): + metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())] + training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0]) + metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data) + + # print the training status + print(training_str + metrics_str) + + if self.writer is None: + self.writer = SummaryWriter() + + for k in self.running_loss: + self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps) + self.running_loss[k] = 0.0 + + def push(self, metrics): + self.total_steps += 1 + + for key in metrics: + if key not in self.running_loss: + self.running_loss[key] = 0.0 + + self.running_loss[key] += metrics[key] + + if self.total_steps % SUM_FREQ == SUM_FREQ-1: + self._print_training_status() + self.running_loss = {} + + def write_dict(self, results): + if self.writer is None: + self.writer = SummaryWriter() + + for key in results: + self.writer.add_scalar(key, results[key], self.total_steps) + + def close(self): + self.writer.close() + + +def train(args): + + model = nn.DataParallel(RAFT(args), device_ids=args.gpus) + print("Parameter Count: %d" % count_parameters(model)) + + if args.restore_ckpt is not None: + model.load_state_dict(torch.load(args.restore_ckpt), strict=False) + + model.cuda() + model.train() + + if args.stage != 'chairs': + model.module.freeze_bn() + + train_loader = datasets.fetch_dataloader(args) + optimizer, scheduler = fetch_optimizer(args, model) + + total_steps = 0 + scaler = GradScaler(enabled=args.mixed_precision) + logger = Logger(model, scheduler) + + VAL_FREQ = 5000 + add_noise = True + + should_keep_training = True + while should_keep_training: + + for i_batch, data_blob in enumerate(train_loader): + optimizer.zero_grad() + image1, image2, flow, valid = [x.cuda() for x in data_blob] + + if args.add_noise: + stdv = np.random.uniform(0.0, 5.0) + image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0) + image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0) + + flow_predictions = model(image1, image2, iters=args.iters) + + loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma) + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) + + scaler.step(optimizer) + scheduler.step() + scaler.update() + + logger.push(metrics) + + if total_steps % VAL_FREQ == VAL_FREQ - 1: + PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name) + torch.save(model.state_dict(), PATH) + + results = {} + for val_dataset in args.validation: + if val_dataset == 'chairs': + results.update(evaluate.validate_chairs(model.module)) + elif val_dataset == 'sintel': + results.update(evaluate.validate_sintel(model.module)) + elif val_dataset == 'kitti': + results.update(evaluate.validate_kitti(model.module)) + + logger.write_dict(results) + + model.train() + if args.stage != 'chairs': + model.module.freeze_bn() + + total_steps += 1 + + if total_steps > args.num_steps: + should_keep_training = False + break + + logger.close() + PATH = 'checkpoints/%s.pth' % args.name + torch.save(model.state_dict(), PATH) + + return PATH + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--name', default='raft', help="name your experiment") + parser.add_argument('--stage', help="determines which dataset to use for training") + parser.add_argument('--restore_ckpt', help="restore checkpoint") + parser.add_argument('--small', action='store_true', help='use small model') + parser.add_argument('--validation', type=str, nargs='+') + + parser.add_argument('--lr', type=float, default=0.00002) + parser.add_argument('--num_steps', type=int, default=100000) + parser.add_argument('--batch_size', type=int, default=6) + parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512]) + parser.add_argument('--gpus', type=int, nargs='+', default=[0,1]) + parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') + + parser.add_argument('--iters', type=int, default=12) + parser.add_argument('--wdecay', type=float, default=.00005) + parser.add_argument('--epsilon', type=float, default=1e-8) + parser.add_argument('--clip', type=float, default=1.0) + parser.add_argument('--dropout', type=float, default=0.0) + parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting') + parser.add_argument('--add_noise', action='store_true') + args = parser.parse_args() + + torch.manual_seed(1234) + np.random.seed(1234) + + if not os.path.isdir('checkpoints'): + os.mkdir('checkpoints') + + train(args) \ No newline at end of file diff --git a/monst3r/third_party/RAFT/train_mixed.sh b/monst3r/third_party/RAFT/train_mixed.sh new file mode 100755 index 0000000..d9b979f --- /dev/null +++ b/monst3r/third_party/RAFT/train_mixed.sh @@ -0,0 +1,6 @@ +#!/bin/bash +mkdir -p checkpoints +python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision +python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision +python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision +python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision diff --git a/monst3r/third_party/RAFT/train_standard.sh b/monst3r/third_party/RAFT/train_standard.sh new file mode 100755 index 0000000..7f559b3 --- /dev/null +++ b/monst3r/third_party/RAFT/train_standard.sh @@ -0,0 +1,6 @@ +#!/bin/bash +mkdir -p checkpoints +python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 +python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 +python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 +python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 diff --git a/monst3r/third_party/__init__.py b/monst3r/third_party/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/monst3r/third_party/raft.py b/monst3r/third_party/raft.py new file mode 100644 index 0000000..ba98a48 --- /dev/null +++ b/monst3r/third_party/raft.py @@ -0,0 +1,80 @@ + +import sys +import argparse +import torch +import json +from os.path import dirname, join +RAFT_PATH_ROOT = join(dirname(__file__), 'RAFT') +RAFT_PATH_CORE = join(RAFT_PATH_ROOT, 'core') +sys.path.append(RAFT_PATH_CORE) +from raft import RAFT, RAFT2 # nopep8 +from utils.utils import InputPadder # nopep8 + +# %% +# utility functions + +def json_to_args(json_path): + # return a argparse.Namespace object + with open(json_path, 'r') as f: + data = json.load(f) + args = argparse.Namespace() + args_dict = args.__dict__ + for key, value in data.items(): + args_dict[key] = value + return args + +def parse_args(parser): + entry = parser.parse_args(args=[]) + json_path = entry.cfg + args = json_to_args(json_path) + args_dict = args.__dict__ + for index, (key, value) in enumerate(vars(entry).items()): + args_dict[key] = value + return args + +def get_input_padder(shape): + return InputPadder(shape, mode='sintel') + + +def load_RAFT(model_path=None): + if model_path is None or 'M' not in model_path: # RAFT1 + parser = argparse.ArgumentParser() + parser.add_argument('--model', help="restore checkpoint", default=model_path) + parser.add_argument('--path', help="dataset for evaluation") + parser.add_argument('--small', action='store_true', help='use small model') + parser.add_argument('--mixed_precision', + action='store_true', help='use mixed precision') + parser.add_argument('--alternate_corr', action='store_true', + help='use efficient correlation implementation') + + # Set default value for --model if model_path is provided + args = parser.parse_args( + ['--model', model_path if model_path else join(RAFT_PATH_ROOT, 'models', 'raft-sintel.pth'), '--path', './']) + + net = RAFT(args) + else: # RAFT2 + parser = argparse.ArgumentParser() + parser.add_argument('--cfg', help='experiment configure file name', default="third_party/RAFT/core/configs/congif_spring_M.json") + parser.add_argument('--model', help='checkpoint path', default=model_path) + parser.add_argument('--device', help='inference device', type=str, default='cpu') + args = parse_args(parser) + net = RAFT2(args) + + if torch.cuda.is_available(): + state_dict = torch.load(args.model) + else: + state_dict = torch.load(args.model, map_location="cpu") + print('Loaded pretrained RAFT model from', args.model) + new_state_dict = {} + for k in state_dict: + if 'module' in k: + name = k[7:] + else: + name = k + new_state_dict[name] = state_dict[k] + net.load_state_dict(new_state_dict) + return net.eval() + +if __name__ == "__main__": + net = load_RAFT(model_path='third_party/RAFT/models/Tartan-C-T432x960-M.pth') + print(net) \ No newline at end of file diff --git a/monst3r/third_party/sam2/.clang-format b/monst3r/third_party/sam2/.clang-format new file mode 100644 index 0000000..39b1b3d --- /dev/null +++ b/monst3r/third_party/sam2/.clang-format @@ -0,0 +1,85 @@ +AccessModifierOffset: -1 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: false +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never diff --git a/monst3r/third_party/sam2/.github/workflows/check_fmt.yml b/monst3r/third_party/sam2/.github/workflows/check_fmt.yml new file mode 100644 index 0000000..0a29b88 --- /dev/null +++ b/monst3r/third_party/sam2/.github/workflows/check_fmt.yml @@ -0,0 +1,17 @@ +name: SAM2/fmt +on: + pull_request: + branches: + - main +jobs: + ufmt_check: + runs-on: ubuntu-latest + steps: + - name: Check formatting + uses: omnilib/ufmt@action-v1 + with: + path: sam2 tools + version: "2.0.0b2" + python-version: "3.10" + black-version: "24.2.0" + usort-version: "1.0.2" diff --git a/monst3r/third_party/sam2/.gitignore b/monst3r/third_party/sam2/.gitignore new file mode 100644 index 0000000..50b9875 --- /dev/null +++ b/monst3r/third_party/sam2/.gitignore @@ -0,0 +1,10 @@ +.vscode/ +.DS_Store +__pycache__/ +*-checkpoint.ipynb +.venv +*.egg* +build/* +_C.* +outputs/* +checkpoints/*.pt diff --git a/monst3r/third_party/sam2/.watchmanconfig b/monst3r/third_party/sam2/.watchmanconfig new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/monst3r/third_party/sam2/.watchmanconfig @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/monst3r/third_party/sam2/CODE_OF_CONDUCT.md b/monst3r/third_party/sam2/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..08b500a --- /dev/null +++ b/monst3r/third_party/sam2/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/monst3r/third_party/sam2/CONTRIBUTING.md b/monst3r/third_party/sam2/CONTRIBUTING.md new file mode 100644 index 0000000..ad15049 --- /dev/null +++ b/monst3r/third_party/sam2/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to segment-anything +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints, using the `ufmt format` command. Linting requires `black==24.2.0`, `usort==1.0.2`, and `ufmt==2.0.0b2`, which can be installed via `pip install -e ".[dev]"`. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to segment-anything, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/monst3r/third_party/sam2/INSTALL.md b/monst3r/third_party/sam2/INSTALL.md new file mode 100644 index 0000000..7f32564 --- /dev/null +++ b/monst3r/third_party/sam2/INSTALL.md @@ -0,0 +1,189 @@ +## Installation + +### Requirements + +- Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this. + * Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`. +- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command. +- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu. + +Then, install SAM 2 from the root of this repository via +```bash +pip install -e ".[notebooks]" +``` + +Note that you may skip building the SAM 2 CUDA extension during installation via environment variable `SAM2_BUILD_CUDA=0`, as follows: +```bash +# skip the SAM 2 CUDA extension +SAM2_BUILD_CUDA=0 pip install -e ".[notebooks]" +``` +This would also skip the post-processing step at runtime (removing small holes and sprinkles in the output masks, which requires the CUDA extension), but shouldn't affect the results in most cases. + +### Building the SAM 2 CUDA extension + +By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build. (In this case, the build errors are hidden unless using `-v` for verbose output in `pip install`.) + +If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, **you can still use SAM 2 for both image and video applications**. The post-processing step (removing small holes and sprinkles in the output masks) will be skipped, but this shouldn't affect the results in most cases. + +If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building the CUDA extension (and raise errors if it fails to build), as follows +```bash +pip uninstall -y SAM-2 && \ +rm -f ./sam2/*.so && \ +SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[notebooks]" +``` + +Note that PyTorch needs to be installed first before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check its version via `nvcc --version`. + +Please check the section below on common installation issues if the CUDA extension fails to build during installation or load at runtime. + +### Common Installation Issues + +Click each issue for its solutions: + +
+ +I got `ImportError: cannot import name '_C' from 'sam2'` + +
+ +This is usually because you haven't run the `pip install -e ".[notebooks]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails. + +In some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root as suggested in https://github.com/facebookresearch/sam2/issues/77. +
+ +
+ +I got `MissingConfigException: Cannot find primary config 'configs/sam2.1/sam2.1_hiera_l.yaml'` + +
+ +This is usually because you haven't run the `pip install -e .` step above, so `sam2` isn't in your Python's `sys.path`. Please run this installation step. In case it still fails after the installation step, you may try manually adding the root of this repo to `PYTHONPATH` via +```bash +export SAM2_REPO_ROOT=/path/to/sam2 # path to this repo +export PYTHONPATH="${SAM2_REPO_ROOT}:${PYTHONPATH}" +``` +to manually add `sam2_configs` into your Python's `sys.path`. + +
+ +
+ +I got `RuntimeError: Error(s) in loading state_dict for SAM2Base` when loading the new SAM 2.1 checkpoints + +
+ +This is likely because you have installed a previous version of this repo, which doesn't have the new modules to support the SAM 2.1 checkpoints yet. Please try the following steps: + +1. pull the latest code from the `main` branch of this repo +2. run `pip uninstall -y SAM-2` to uninstall any previous installations +3. then install the latest repo again using `pip install -e ".[notebooks]"` + +In case the steps above still don't resolve the error, please try running in your Python environment the following +```python +from sam2.modeling import sam2_base + +print(sam2_base.__file__) +``` +and check whether the content in the printed local path of `sam2/modeling/sam2_base.py` matches the latest one in https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/sam2_base.py (e.g. whether your local file has `no_obj_embed_spatial`) to indentify if you're still using a previous installation. + +
+ +
+ +My installation failed with `CUDA_HOME environment variable is not set` + +
+ +This usually happens because the installation step cannot find the CUDA toolkits (that contain the NVCC compiler) to build a custom CUDA kernel in SAM 2. Please install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) or the version that matches the CUDA version for your PyTorch installation. If the error persists after installing CUDA toolkits, you may explicitly specify `CUDA_HOME` via +``` +export CUDA_HOME=/usr/local/cuda # change to your CUDA toolkit path +``` +and rerun the installation. + +Also, you should make sure +``` +python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)' +``` +print `(True, a directory with cuda)` to verify that the CUDA toolkits are correctly set up. + +If you are still having problems after verifying that the CUDA toolkit is installed and the `CUDA_HOME` environment variable is set properly, you may have to add the `--no-build-isolation` flag to the pip command: +``` +pip install --no-build-isolation -e . +``` + +
+ +
+ +I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar errors) + +
+ +This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version. + +In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`. + +We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0. +
+ +
+ +I got `CUDA error: no kernel image is available for execution on the device` + +
+ +A possible cause could be that the CUDA kernel is somehow not compiled towards your GPU's CUDA [capability](https://developer.nvidia.com/cuda-gpus). This could happen if the installation is done in an environment different from the runtime (e.g. in a slurm system). + +You can try pulling the latest code from the SAM 2 repo and running the following +``` +export TORCH_CUDA_ARCH_LIST=9.0 8.0 8.6 8.9 7.0 7.2 7.5 6.0` +``` +to manually specify the CUDA capability in the compilation target that matches your GPU. +
+ +
+ +I got `RuntimeError: No available kernel. Aborting execution.` (or similar errors) + +
+ +This is probably because your machine doesn't have a GPU or a compatible PyTorch version for Flash Attention (see also https://discuss.pytorch.org/t/using-f-scaled-dot-product-attention-gives-the-error-runtimeerror-no-available-kernel-aborting-execution/180900 for a discussion in PyTorch forum). You may be able to resolve this error by replacing the line +```python +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +``` +in [`sam2/modeling/sam/transformer.py`](sam2/modeling/sam/transformer.py) with +```python +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True +``` +to relax the attention kernel setting and use other kernels than Flash Attention. +
+ +
+ +I got `Error compiling objects for extension` + +
+ +You may see error log of: +> unsupported Microsoft Visual Studio version! Only the versions between 2017 and 2022 (inclusive) are supported! The nvcc flag '-allow-unsupported-compiler' can be used to override this version check; however, using an unsupported host compiler may cause compilation failure or incorrect run time execution. Use at your own risk. + +This is probably because your versions of CUDA and Visual Studio are incompatible. (see also https://stackoverflow.com/questions/78515942/cuda-compatibility-with-visual-studio-2022-version-17-10 for a discussion in stackoverflow).
+You may be able to fix this by adding the `-allow-unsupported-compiler` argument to `nvcc` after L48 in the [setup.py](https://github.com/facebookresearch/sam2/blob/main/setup.py).
+After adding the argument, `get_extension()` will look like this: +```python +def get_extensions(): + srcs = ["sam2/csrc/connected_components.cu"] + compile_args = { + "cxx": [], + "nvcc": [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + "-allow-unsupported-compiler" # Add this argument + ], + } + ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] + return ext_modules +``` +
diff --git a/monst3r/third_party/sam2/LICENSE b/monst3r/third_party/sam2/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/monst3r/third_party/sam2/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/monst3r/third_party/sam2/LICENSE_cctorch b/monst3r/third_party/sam2/LICENSE_cctorch new file mode 100644 index 0000000..23da14a --- /dev/null +++ b/monst3r/third_party/sam2/LICENSE_cctorch @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2020, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/monst3r/third_party/sam2/MANIFEST.in b/monst3r/third_party/sam2/MANIFEST.in new file mode 100644 index 0000000..794311f --- /dev/null +++ b/monst3r/third_party/sam2/MANIFEST.in @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +recursive-include sam2 *.yaml #include all config files diff --git a/monst3r/third_party/sam2/README.md b/monst3r/third_party/sam2/README.md new file mode 100644 index 0000000..65654f5 --- /dev/null +++ b/monst3r/third_party/sam2/README.md @@ -0,0 +1,219 @@ +# SAM 2: Segment Anything in Images and Videos + +**[AI at Meta, FAIR](https://ai.meta.com/research/)** + +[Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/) + +[[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)] + +![SAM 2 architecture](assets/model_diagram.png?raw=true) + +**Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains. + +![SA-V dataset](assets/sa_v_dataset.jpg?raw=true) + +## Latest updates + +**09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released** + +- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details. + * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below. +- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started. +- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details. + +## Installation + +SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.3.1` and `torchvision>=0.18.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using: + +```bash +git clone https://github.com/facebookresearch/sam2.git && cd sam2 + +pip install -e . +``` +If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu. + +To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by: + +```bash +pip install -e ".[notebooks]" +``` + +Note: +1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.3.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.3.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`. +2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version. +3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases). + +Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions. + +## Getting Started + +### Download Checkpoints + +First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running: + +```bash +cd checkpoints && \ +./download_ckpts.sh && \ +cd .. +``` + +or individually from: + +- [sam2.1_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt) +- [sam2.1_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt) +- [sam2.1_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt) +- [sam2.1_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt) + +(note that these are the improved checkpoints denoted as SAM 2.1; see [Model Description](#model-description) for details.) + +Then SAM 2 can be used in a few lines as follows for image and video prediction. + +### Image prediction + +SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting. + +```python +import torch +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor + +checkpoint = "./checkpoints/sam2.1_hiera_large.pt" +model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml" +predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint)) + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + predictor.set_image() + masks, _, _ = predictor.predict() +``` + +Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb)) for static image use cases. + +SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/automatic_mask_generator_example.ipynb)) for automatic mask generation in images. + +### Video prediction + +For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video. + +```python +import torch +from sam2.build_sam import build_sam2_video_predictor + +checkpoint = "./checkpoints/sam2.1_hiera_large.pt" +model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml" +predictor = build_sam2_video_predictor(model_cfg, checkpoint) + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + state = predictor.init_state() + + # add new prompts and instantly get the output on the same frame + frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, ): + + # propagate the prompts to get masklets throughout the video + for frame_idx, object_ids, masks in predictor.propagate_in_video(state): + ... +``` + +Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/video_predictor_example.ipynb)) for details on how to add click or box prompts, make refinements, and track multiple objects in videos. + +## Load from 🤗 Hugging Face + +Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`). + +For image prediction: + +```python +import torch +from sam2.sam2_image_predictor import SAM2ImagePredictor + +predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large") + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + predictor.set_image() + masks, _, _ = predictor.predict() +``` + +For video prediction: + +```python +import torch +from sam2.sam2_video_predictor import SAM2VideoPredictor + +predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large") + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + state = predictor.init_state() + + # add new prompts and instantly get the output on the same frame + frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, ): + + # propagate the prompts to get masklets throughout the video + for frame_idx, object_ids, masks in predictor.propagate_in_video(state): + ... +``` + +## Model Description + +### SAM 2.1 checkpoints + +The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024. +| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** | +| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: | +| sam2.1_hiera_tiny
([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 47.2 | 76.5 | 71.8 | 77.3 | +| sam2.1_hiera_small
([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 76.6 | 73.5 | 78.3 | +| sam2.1_hiera_base_plus
([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 78.2 | 73.7 | 78.2 | +| sam2.1_hiera_large
([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 79.5 | 74.6 | 80.6 | + +### SAM 2 checkpoints + +The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows: + +| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** | +| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: | +| sam2_hiera_tiny
([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 | +| sam2_hiera_small
([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 | +| sam2_hiera_base_plus
([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 | +| sam2_hiera_large
([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 | + +\* Compile the model by setting `compile_image_encoder: True` in the config. + +## Segment Anything Video Dataset + +See [sav_dataset/README.md](sav_dataset/README.md) for details. + +## Training SAM 2 + +You can train or fine-tune SAM 2 on custom datasets of images, videos, or both. Please check the training [README](training/README.md) on how to get started. + +## Web demo for SAM 2 + +We have released the frontend + backend code for the SAM 2 web demo (a locally deployable version similar to https://sam2.metademolab.com/demo). Please see the web demo [README](demo/README.md) for details. + +## License + +The SAM 2 model checkpoints, SAM 2 demo code (front-end and back-end), and SAM 2 training code are licensed under [Apache 2.0](./LICENSE), however the [Inter Font](https://github.com/rsms/inter?tab=OFL-1.1-1-ov-file) and [Noto Color Emoji](https://github.com/googlefonts/noto-emoji) used in the SAM 2 demo code are made available under the [SIL Open Font License, version 1.1](https://openfontlicense.org/open-font-license-official-text/). + +## Contributing + +See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). + +## Contributors + +The SAM 2 project was made possible with the help of many contributors (alphabetical): + +Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang. + +Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions. + +## Citing SAM 2 + +If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry. + +```bibtex +@article{ravi2024sam2, + title={SAM 2: Segment Anything in Images and Videos}, + author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph}, + journal={arXiv preprint arXiv:2408.00714}, + url={https://arxiv.org/abs/2408.00714}, + year={2024} +} +``` diff --git a/monst3r/third_party/sam2/backend.Dockerfile b/monst3r/third_party/sam2/backend.Dockerfile new file mode 100644 index 0000000..adec61d --- /dev/null +++ b/monst3r/third_party/sam2/backend.Dockerfile @@ -0,0 +1,64 @@ +ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime +ARG MODEL_SIZE=base_plus + +FROM ${BASE_IMAGE} + +# Gunicorn environment variables +ENV GUNICORN_WORKERS=1 +ENV GUNICORN_THREADS=2 +ENV GUNICORN_PORT=5000 + +# SAM 2 environment variables +ENV APP_ROOT=/opt/sam2 +ENV PYTHONUNBUFFERED=1 +ENV SAM2_BUILD_CUDA=0 +ENV MODEL_SIZE=${MODEL_SIZE} + +# Install system requirements +RUN apt-get update && apt-get install -y --no-install-recommends \ + ffmpeg \ + libavutil-dev \ + libavcodec-dev \ + libavformat-dev \ + libswscale-dev \ + pkg-config \ + build-essential \ + libffi-dev + +COPY setup.py . +COPY README.md . + +RUN pip install --upgrade pip setuptools +RUN pip install -e ".[interactive-demo]" + +# https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/issues/69#issuecomment-1826764707 +RUN rm /opt/conda/bin/ffmpeg && ln -s /bin/ffmpeg /opt/conda/bin/ffmpeg + +# Make app directory. This directory will host all files required for the +# backend and SAM 2 inference files. +RUN mkdir ${APP_ROOT} + +# Copy backend server files +COPY demo/backend/server ${APP_ROOT}/server + +# Copy SAM 2 inference files +COPY sam2 ${APP_ROOT}/server/sam2 + +# Download SAM 2.1 checkpoints +ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_tiny.pt +ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_small.pt +ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_base_plus.pt +ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_large.pt + +WORKDIR ${APP_ROOT}/server + +# https://pythonspeed.com/articles/gunicorn-in-docker/ +CMD gunicorn --worker-tmp-dir /dev/shm \ + --worker-class gthread app:app \ + --log-level info \ + --access-logfile /dev/stdout \ + --log-file /dev/stderr \ + --workers ${GUNICORN_WORKERS} \ + --threads ${GUNICORN_THREADS} \ + --bind 0.0.0.0:${GUNICORN_PORT} \ + --timeout 60 diff --git a/monst3r/third_party/sam2/docker-compose.yaml b/monst3r/third_party/sam2/docker-compose.yaml new file mode 100644 index 0000000..7a5395a --- /dev/null +++ b/monst3r/third_party/sam2/docker-compose.yaml @@ -0,0 +1,42 @@ +services: + frontend: + image: sam2/frontend + build: + context: ./demo/frontend + dockerfile: frontend.Dockerfile + ports: + - 7262:80 + + backend: + image: sam2/backend + build: + context: . + dockerfile: backend.Dockerfile + ports: + - 7263:5000 + volumes: + - ./demo/data/:/data/:rw + environment: + - SERVER_ENVIRONMENT=DEV + - GUNICORN_WORKERS=1 + # Inference API needs to have at least 2 threads to handle an incoming + # parallel cancel propagation request + - GUNICORN_THREADS=2 + - GUNICORN_PORT=5000 + - API_URL=http://localhost:7263 + - DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 + # # ffmpeg/video encode settings + - FFMPEG_NUM_THREADS=1 + - VIDEO_ENCODE_CODEC=libx264 + - VIDEO_ENCODE_CRF=23 + - VIDEO_ENCODE_FPS=24 + - VIDEO_ENCODE_MAX_WIDTH=1280 + - VIDEO_ENCODE_MAX_HEIGHT=720 + - VIDEO_ENCODE_VERBOSE=False + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] diff --git a/monst3r/third_party/sam2/pyproject.toml b/monst3r/third_party/sam2/pyproject.toml new file mode 100644 index 0000000..f7e8652 --- /dev/null +++ b/monst3r/third_party/sam2/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = [ + "setuptools>=61.0", + "torch>=2.3.1", + ] +build-backend = "setuptools.build_meta" diff --git a/monst3r/third_party/sam2/sam2/__init__.py b/monst3r/third_party/sam2/sam2/__init__.py new file mode 100644 index 0000000..0712dd0 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from hydra import initialize_config_module +from hydra.core.global_hydra import GlobalHydra + +if not GlobalHydra.instance().is_initialized(): + initialize_config_module("sam2", version_base="1.2") diff --git a/monst3r/third_party/sam2/sam2/automatic_mask_generator.py b/monst3r/third_party/sam2/sam2/automatic_mask_generator.py new file mode 100644 index 0000000..065e469 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/automatic_mask_generator.py @@ -0,0 +1,454 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from sam2.modeling.sam2_base import SAM2Base +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2.utils.amg import ( + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + MaskData, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SAM2AutomaticMaskGenerator: + def __init__( + self, + model: SAM2Base, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.8, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + mask_threshold: float = 0.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + use_m2m: bool = False, + multimask_output: bool = True, + **kwargs, + ) -> None: + """ + Using a SAM 2 model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM 2 with a HieraL backbone. + + Arguments: + model (Sam): The SAM 2 model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + mask_threshold (float): Threshold for binarizing the mask logits + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + use_m2m (bool): Whether to add a one step refinement using previous mask predictions. + multimask_output (bool): Whether to output multimask at each point of the grid. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + try: + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + except ImportError as e: + print("Please install pycocotools") + raise e + + self.predictor = SAM2ImagePredictor( + model, + max_hole_area=min_mask_region_area, + max_sprinkle_area=min_mask_region_area, + ) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.mask_threshold = mask_threshold + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + self.use_m2m = use_m2m + self.multimask_output = multimask_output + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2AutomaticMaskGenerator): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [ + coco_encode_rle(rle) for rle in mask_data["rles"] + ] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch( + points, cropped_im_size, crop_box, orig_size, normalize=True + ) + data.cat(batch_data) + del batch_data + self.predictor.reset_predictor() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + normalize=False, + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + points = torch.as_tensor( + points, dtype=torch.float32, device=self.predictor.device + ) + in_points = self.predictor._transforms.transform_coords( + points, normalize=normalize, orig_hw=im_size + ) + in_labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, iou_preds, low_res_masks = self.predictor._predict( + in_points[:, None, :], + in_labels[:, None], + multimask_output=self.multimask_output, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=points.repeat_interleave(masks.shape[1], dim=0), + low_res_masks=low_res_masks.flatten(0, 1), + ) + del masks + + if not self.use_m2m: + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate and filter by stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + else: + # One step refinement using previous mask predictions + in_points = self.predictor._transforms.transform_coords( + data["points"], normalize=normalize, orig_hw=im_size + ) + labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, ious = self.refine_with_m2m( + in_points, labels, data["low_res_masks"], self.points_per_batch + ) + data["masks"] = masks.squeeze(1) + data["iou_preds"] = ious.squeeze(1) + + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge( + data["boxes"], crop_box, [0, 0, orig_w, orig_h] + ) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data + + def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch): + new_masks = [] + new_iou_preds = [] + + for cur_points, cur_point_labels, low_res_mask in batch_iterator( + points_per_batch, points, point_labels, low_res_masks + ): + best_masks, best_iou_preds, _ = self.predictor._predict( + cur_points[:, None, :], + cur_point_labels[:, None], + mask_input=low_res_mask[:, None, :], + multimask_output=False, + return_logits=True, + ) + new_masks.append(best_masks) + new_iou_preds.append(best_iou_preds) + masks = torch.cat(new_masks, dim=0) + return masks, torch.cat(new_iou_preds, dim=0) diff --git a/monst3r/third_party/sam2/sam2/build_sam.py b/monst3r/third_party/sam2/sam2/build_sam.py new file mode 100644 index 0000000..7cfc451 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/build_sam.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf + +import sam2 + +# Check if the user is running Python from the parent directory of the sam2 repo +# (i.e. the directory where this repo is cloned into) -- this is not supported since +# it could shadow the sam2 package and cause issues. +if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): + # If the user has "sam2/sam2" in their path, they are likey importing the repo itself + # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). + # This typically happens because the user is running Python from the parent directory + # that contains the sam2 repo they cloned. + raise RuntimeError( + "You're likely running Python from the parent directory of the sam2 repository " + "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " + "This is not supported since the `sam2` Python package could be shadowed by the " + "repository name (the repository is also named `sam2` and contains the Python package " + "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " + "rather than its parent dir, or from your home directory) after installing SAM 2." + ) + + +HF_MODEL_ID_TO_FILENAMES = { + "facebook/sam2-hiera-tiny": ( + "configs/sam2/sam2_hiera_t.yaml", + "sam2_hiera_tiny.pt", + ), + "facebook/sam2-hiera-small": ( + "configs/sam2/sam2_hiera_s.yaml", + "sam2_hiera_small.pt", + ), + "facebook/sam2-hiera-base-plus": ( + "configs/sam2/sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ( + "configs/sam2/sam2_hiera_l.yaml", + "sam2_hiera_large.pt", + ), + "facebook/sam2.1-hiera-tiny": ( + "configs/sam2.1/sam2.1_hiera_t.yaml", + "sam2.1_hiera_tiny.pt", + ), + "facebook/sam2.1-hiera-small": ( + "configs/sam2.1/sam2.1_hiera_s.yaml", + "sam2.1_hiera_small.pt", + ), + "facebook/sam2.1-hiera-base-plus": ( + "configs/sam2.1/sam2.1_hiera_b+.yaml", + "sam2.1_hiera_base_plus.pt", + ), + "facebook/sam2.1-hiera-large": ( + "configs/sam2.1/sam2.1_hiera_l.yaml", + "sam2.1_hiera_large.pt", + ), +} + + +def build_sam2( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_video_predictor( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def _hf_download(model_id): + from huggingface_hub import hf_hub_download + + config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return config_name, ckpt_path + + +def build_sam2_hf(model_id, **kwargs): + config_name, ckpt_path = _hf_download(model_id) + return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + + +def build_sam2_video_predictor_hf(model_id, **kwargs): + config_name, ckpt_path = _hf_download(model_id) + return build_sam2_video_predictor( + config_file=config_name, ckpt_path=ckpt_path, **kwargs + ) + + +def _load_checkpoint(model, ckpt_path): + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() + logging.info("Loaded checkpoint sucessfully") diff --git a/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml b/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml new file mode 100644 index 0000000..cbee3cf --- /dev/null +++ b/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml b/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml new file mode 100644 index 0000000..33c9097 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml @@ -0,0 +1,120 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml b/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml new file mode 100644 index 0000000..8e803df --- /dev/null +++ b/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml @@ -0,0 +1,119 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml b/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml new file mode 100644 index 0000000..983c2ea --- /dev/null +++ b/monst3r/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml @@ -0,0 +1,121 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml b/monst3r/third_party/sam2/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml new file mode 100644 index 0000000..2046791 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml @@ -0,0 +1,339 @@ +# @package _global_ + +scratch: + resolution: 1024 + train_batch_size: 1 + num_train_workers: 10 + num_frames: 8 + max_num_objects: 3 + base_lr: 5.0e-6 + vision_lr: 3.0e-06 + phases_per_epoch: 1 + num_epochs: 40 + +dataset: + # PATHS to Dataset + img_folder: null # PATH to MOSE JPEGImages folder + gt_folder: null # PATH to MOSE Annotations folder + file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training + multiplier: 2 + +# Video transforms +vos: + train_transforms: + - _target_: training.dataset.transforms.ComposeAPI + transforms: + - _target_: training.dataset.transforms.RandomHorizontalFlip + consistent_transform: True + - _target_: training.dataset.transforms.RandomAffine + degrees: 25 + shear: 20 + image_interpolation: bilinear + consistent_transform: True + - _target_: training.dataset.transforms.RandomResizeAPI + sizes: ${scratch.resolution} + square: true + consistent_transform: True + - _target_: training.dataset.transforms.ColorJitter + consistent_transform: True + brightness: 0.1 + contrast: 0.03 + saturation: 0.03 + hue: null + - _target_: training.dataset.transforms.RandomGrayscale + p: 0.05 + consistent_transform: True + - _target_: training.dataset.transforms.ColorJitter + consistent_transform: False + brightness: 0.1 + contrast: 0.05 + saturation: 0.05 + hue: null + - _target_: training.dataset.transforms.ToTensorAPI + - _target_: training.dataset.transforms.NormalizeAPI + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +trainer: + _target_: training.trainer.Trainer + mode: train_only + max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}} + accelerator: cuda + seed_value: 123 + + model: + _target_: training.model.sam2.SAM2Train + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + drop_path_rate: 0.1 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: ${scratch.resolution} + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # compile_image_encoder: False + + ####### Training specific params ####### + # box/point input and corrections + prob_to_use_pt_input_for_train: 0.5 + prob_to_use_pt_input_for_eval: 0.0 + prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points + prob_to_use_box_input_for_eval: 0.0 + prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors + num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame) + num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame + rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2 + add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame) + # maximum 2 initial conditioning frames + num_init_cond_frames_for_train: 2 + rand_init_cond_frames_for_train: True # random 1~2 + num_correction_pt_per_frame: 7 + use_act_ckpt_iterative_pt_sampling: false + + + + num_init_cond_frames_for_eval: 1 # only mask on the first frame + forward_backbone_per_frame_for_eval: True + + + data: + train: + _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset + phases_per_epoch: ${scratch.phases_per_epoch} + batch_sizes: + - ${scratch.train_batch_size} + + datasets: + - _target_: training.dataset.utils.RepeatFactorWrapper + dataset: + _target_: training.dataset.utils.ConcatDataset + datasets: + - _target_: training.dataset.vos_dataset.VOSDataset + transforms: ${vos.train_transforms} + training: true + video_dataset: + _target_: training.dataset.vos_raw_dataset.PNGRawDataset + img_folder: ${dataset.img_folder} + gt_folder: ${dataset.gt_folder} + file_list_txt: ${dataset.file_list_txt} + sampler: + _target_: training.dataset.vos_sampler.RandomUniformSampler + num_frames: ${scratch.num_frames} + max_num_objects: ${scratch.max_num_objects} + multiplier: ${dataset.multiplier} + shuffle: True + num_workers: ${scratch.num_train_workers} + pin_memory: True + drop_last: True + collate_fn: + _target_: training.utils.data_utils.collate_fn + _partial_: true + dict_key: all + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + optimizer: + _target_: torch.optim.AdamW + + gradient_clip: + _target_: training.optimizer.GradientClipper + max_norm: 0.1 + norm_type: 2 + + param_group_modifiers: + - _target_: training.optimizer.layer_decay_param_modifier + _partial_: True + layer_decay_value: 0.9 + apply_to: 'image_encoder.trunk' + overrides: + - pattern: '*pos_embed*' + value: 1.0 + + options: + lr: + - scheduler: + _target_: fvcore.common.param_scheduler.CosineParamScheduler + start_value: ${scratch.base_lr} + end_value: ${divide:${scratch.base_lr},10} + - scheduler: + _target_: fvcore.common.param_scheduler.CosineParamScheduler + start_value: ${scratch.vision_lr} + end_value: ${divide:${scratch.vision_lr},10} + param_names: + - 'image_encoder.*' + weight_decay: + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: 0.1 + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: 0.0 + param_names: + - '*bias*' + module_cls_names: ['torch.nn.LayerNorm'] + + loss: + all: + _target_: training.loss_fns.MultiStepMultiMasksAndIous + weight_dict: + loss_mask: 20 + loss_dice: 1 + loss_iou: 1 + loss_class: 1 + supervise_all_iou: true + iou_use_l1_loss: true + pred_obj_scores: true + focal_gamma_obj_score: 0.0 + focal_alpha_obj_score: -1.0 + + distributed: + backend: nccl + find_unused_parameters: True + + logging: + tensorboard_writer: + _target_: training.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + log_dir: ${launcher.experiment_log_dir}/logs + log_freq: 10 + + # initialize from a SAM 2 checkpoint + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + model_weight_initializer: + _partial_: True + _target_: training.utils.checkpoint_utils.load_state_dict_into_model + strict: True + ignore_unexpected_keys: null + ignore_missing_keys: null + + state_dict: + _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels + checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint + ckpt_state_dict_keys: ['model'] + +launcher: + num_nodes: 1 + gpus_per_node: 8 + experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name} + +# SLURM args if running on a cluster +submitit: + partition: null + account: null + qos: null + cpus_per_task: 10 + use_cluster: false + timeout_hour: 24 + name: null + port_range: [10000, 65000] + diff --git a/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_b+.yaml b/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_b+.yaml new file mode 100644 index 0000000..58f3eb8 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_b+.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_l.yaml b/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_l.yaml new file mode 100644 index 0000000..918667f --- /dev/null +++ b/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_s.yaml b/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_s.yaml new file mode 100644 index 0000000..26e5d4d --- /dev/null +++ b/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_s.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_t.yaml b/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_t.yaml new file mode 100644 index 0000000..a62c903 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/configs/sam2/sam2_hiera_t.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/csrc/connected_components.cu b/monst3r/third_party/sam2/sam2/csrc/connected_components.cu new file mode 100644 index 0000000..ced21eb --- /dev/null +++ b/monst3r/third_party/sam2/sam2/csrc/connected_components.cu @@ -0,0 +1,289 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. + +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// adapted from https://github.com/zsef123/Connected_components_PyTorch +// with license found in the LICENSE_cctorch file in the root directory. +#include +#include +#include +#include +#include +#include + +// 2d +#define BLOCK_ROWS 16 +#define BLOCK_COLS 16 + +namespace cc2d { + +template +__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { + return (bitmap >> pos) & 1; +} + +__device__ int32_t find(const int32_t* s_buf, int32_t n) { + while (s_buf[n] != n) + n = s_buf[n]; + return n; +} + +__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { + const int32_t id = n; + while (s_buf[n] != n) { + n = s_buf[n]; + s_buf[id] = n; + } + return n; +} + +__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { + bool done; + do { + a = find(s_buf, a); + b = find(s_buf, b); + + if (a < b) { + int32_t old = atomicMin(s_buf + b, a); + done = (old == b); + b = old; + } else if (b < a) { + int32_t old = atomicMin(s_buf + a, b); + done = (old == a); + a = old; + } else + done = true; + + } while (!done); +} + +__global__ void +init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + label[idx] = idx; +} + +__global__ void +merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + uint32_t P = 0; + + if (img[idx]) + P |= 0x777; + if (row + 1 < H && img[idx + W]) + P |= 0x777 << 4; + if (col + 1 < W && img[idx + 1]) + P |= 0x777 << 1; + + if (col == 0) + P &= 0xEEEE; + if (col + 1 >= W) + P &= 0x3333; + else if (col + 2 >= W) + P &= 0x7777; + + if (row == 0) + P &= 0xFFF0; + if (row + 1 >= H) + P &= 0xFF; + + if (P > 0) { + // If need check about top-left pixel(if flag the first bit) and hit the + // top-left pixel + if (hasBit(P, 0) && img[idx - W - 1]) { + union_(label, idx, idx - 2 * W - 2); // top left block + } + + if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) + union_(label, idx, idx - 2 * W); // top bottom block + + if (hasBit(P, 3) && img[idx + 2 - W]) + union_(label, idx, idx - 2 * W + 2); // top right block + + if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1])) + union_(label, idx, idx - 2); // just left block + } +} + +__global__ void compression(int32_t* label, const int32_t W, const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + find_n_compress(label, idx); +} + +__global__ void final_labeling( + const uint8_t* img, + int32_t* label, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx] + 1; + + if (img[idx]) + label[idx] = y; + else + label[idx] = 0; + + if (col + 1 < W) { + if (img[idx + 1]) + label[idx + 1] = y; + else + label[idx + 1] = 0; + + if (row + 1 < H) { + if (img[idx + W + 1]) + label[idx + W + 1] = y; + else + label[idx + W + 1] = 0; + } + } + + if (row + 1 < H) { + if (img[idx + W]) + label[idx + W] = y; + else + label[idx + W] = 0; + } +} + +__global__ void init_counting( + const int32_t* label, + int32_t* count_init, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + atomicAdd(count_init + count_idx, 1); + } +} + +__global__ void final_counting( + const int32_t* label, + const int32_t* count_init, + int32_t* count_final, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + count_final[idx] = count_init[count_idx]; + } else { + count_final[idx] = 0; + } +} + +} // namespace cc2d + +std::vector get_connected_componnets( + const torch::Tensor& inputs) { + AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); + AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM( + inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); + + const uint32_t N = inputs.size(0); + const uint32_t C = inputs.size(1); + const uint32_t H = inputs.size(2); + const uint32_t W = inputs.size(3); + + AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM((H % 2) == 0, "height must be an even number"); + AT_ASSERTM((W % 2) == 0, "width must be an even number"); + + // label must be uint32_t + auto label_options = + torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); + torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); + + dim3 grid = dim3( + ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, + ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); + dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); + dim3 grid_count = + dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); + dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + for (int n = 0; n < N; n++) { + uint32_t offset = n * H * W; + + cc2d::init_labeling<<>>( + labels.data_ptr() + offset, W, H); + cc2d::merge<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + cc2d::compression<<>>( + labels.data_ptr() + offset, W, H); + cc2d::final_labeling<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + + // get the counting of each pixel + cc2d::init_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + W, + H); + cc2d::final_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + counts_final.data_ptr() + offset, + W, + H); + } + + // returned values are [labels, counts] + std::vector outputs; + outputs.push_back(labels); + outputs.push_back(counts_final); + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "get_connected_componnets", + &get_connected_componnets, + "get_connected_componnets"); +} diff --git a/monst3r/third_party/sam2/sam2/modeling/__init__.py b/monst3r/third_party/sam2/sam2/modeling/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/monst3r/third_party/sam2/sam2/modeling/backbones/__init__.py b/monst3r/third_party/sam2/sam2/modeling/backbones/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/backbones/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/monst3r/third_party/sam2/sam2/modeling/backbones/hieradet.py b/monst3r/third_party/sam2/sam2/modeling/backbones/hieradet.py new file mode 100644 index 0000000..19ac77b --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,317 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from functools import partial +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from iopath.common.file_io import g_pathmgr + +from sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + +from sam2.modeling.sam2_utils import DropPath, MLP + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.num_heads = num_heads + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, stride=q_stride, ceil_mode=False + ) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Hiera(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + weights_path=None, + return_interm_layers=True, # return feats from every stage + ): + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + if weights_path is not None: + with g_pathmgr.open(weights_path, "rb") as f: + chkpt = torch.load(f, map_location="cpu") + logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile( + [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs + + def get_layer_id(self, layer_name): + # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + num_layers = self.get_num_layers() + + if layer_name.find("rel_pos") != -1: + return num_layers + 1 + elif layer_name.find("pos_embed") != -1: + return 0 + elif layer_name.find("patch_embed") != -1: + return 0 + elif layer_name.find("blocks") != -1: + return int(layer_name.split("blocks")[1].split(".")[1]) + 1 + else: + return num_layers + 1 + + def get_num_layers(self) -> int: + return len(self.blocks) diff --git a/monst3r/third_party/sam2/sam2/modeling/backbones/image_encoder.py b/monst3r/third_party/sam2/sam2/modeling/backbones/image_encoder.py new file mode 100644 index 0000000..37e9266 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ImageEncoder(nn.Module): + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + + def forward(self, sample: torch.Tensor): + # Forward through backbone + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A modified variant of Feature Pyramid Network (FPN) neck + (we remove output conv and also do bicubic interpolation similar to ViT + pos embed interpolation) + """ + + def __init__( + self, + position_encoding: nn.Module, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """Initialize the neck + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + :param neck_norm: the normalization to use + """ + super().__init__() + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + self.d_model = d_model + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=( + None if self.fpn_interp_model == "nearest" else False + ), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos diff --git a/monst3r/third_party/sam2/sam2/modeling/backbones/utils.py b/monst3r/third_party/sam2/sam2/modeling/backbones/utils.py new file mode 100644 index 0000000..32d55c7 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/backbones/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some utilities for backbones, in particular for windowing""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/monst3r/third_party/sam2/sam2/modeling/memory_attention.py b/monst3r/third_party/sam2/sam2/modeling/memory_attention.py new file mode 100644 index 0000000..0b07f9d --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/memory_attention.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor + +from sam2.modeling.sam.transformer import RoPEAttention + +from sam2.modeling.sam2_utils import get_activation_fn, get_clones + + +class MemoryAttentionLayer(nn.Module): + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert ( + curr.shape[1] == memory.shape[1] + ), "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/monst3r/third_party/sam2/sam2/modeling/memory_encoder.py b/monst3r/third_party/sam2/sam2/modeling/memory_encoder.py new file mode 100644 index 0000000..f60202d --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/memory_encoder.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d + + +class MaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/monst3r/third_party/sam2/sam2/modeling/position_encoding.py b/monst3r/third_party/sam2/sam2/modeling/position_encoding.py new file mode 100644 index 0000000..52ac226 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/position_encoding.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention Is All You Need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/monst3r/third_party/sam2/sam2/modeling/sam/__init__.py b/monst3r/third_party/sam2/sam2/modeling/sam/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/sam/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/monst3r/third_party/sam2/sam2/modeling/sam/mask_decoder.py b/monst3r/third_party/sam2/sam2/modeling/sam/mask_decoder.py new file mode 100644 index 0000000..9bebc03 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.sam2_utils import LayerNorm2d, MLP + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/monst3r/third_party/sam2/sam2/modeling/sam/prompt_encoder.py b/monst3r/third_party/sam2/sam2/modeling/sam/prompt_encoder.py new file mode 100644 index 0000000..6b3bbb9 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.position_encoding import PositionEmbeddingRandom + +from sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings diff --git a/monst3r/third_party/sam2/sam2/modeling/sam/transformer.py b/monst3r/third_party/sam2/sam2/modeling/sam/transformer.py new file mode 100644 index 0000000..b5b6fa2 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/sam/transformer.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import math +import warnings +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis +from sam2.modeling.sam2_utils import MLP +from sam2.utils.misc import get_sdpa_settings + +warnings.simplefilter(action="ignore", category=FutureWarning) +# Check whether Flash Attention is available (and use it by default) +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +# A fallback setting to allow all available kernels if Flash Attention fails +ALLOW_ALL_KERNELS = False + + +def sdp_kernel_context(dropout_p): + """ + Get the context for the attention scaled dot-product kernel. We use Flash Attention + by default, but fall back to all available kernels if Flash Attention fails. + """ + if ALLOW_ALL_KERNELS: + return contextlib.nullcontext() + + return torch.backends.cuda.sdp_kernel( + enable_flash=USE_FLASH_ATTN, + # if Flash attention kernel is off, then math kernel needs to be enabled + enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + enable_mem_efficient=OLD_GPU, + ) + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/monst3r/third_party/sam2/sam2/modeling/sam2_base.py b/monst3r/third_party/sam2/sam2/modeling/sam2_base.py new file mode 100644 index 0000000..a5d243a --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/sam2_base.py @@ -0,0 +1,907 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F + +from torch.nn.init import trunc_normal_ + +from sam2.modeling.sam.mask_decoder import MaskDecoder +from sam2.modeling.sam.prompt_encoder import PromptEncoder +from sam2.modeling.sam.transformer import TwoWayTransformer +from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers + # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + use_signed_tpos_enc_to_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # add no obj embedding to spatial frames + no_obj_embed_spatial: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = image_encoder.neck.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr( + self.memory_encoder.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) + + self._build_sam_heads() + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print( + "Image encoder compilation is enabled. First forward pass will be slow." + ) + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning" + "See notebooks/video_predictor_example.ipynb for an inference example." + ) + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros( + mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device + ) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward_image(self, img_batch: torch.Tensor): + """Get the image feature on the input batch.""" + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with stride>1), in which case + # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame. + stride = 1 if self.training else self.memory_temporal_stride_for_eval + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].to(device, non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + ) + to_cat_memory_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + ( + ( + (frame_idx - t) * tpos_sign_mul + if self.use_signed_tpos_enc_to_obj_ptrs + else abs(frame_idx - t) + ), + out["obj_ptr"], + ) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape( + -1, B, C // self.mem_dim, self.mem_dim + ) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) + + return maskmem_features, maskmem_pos_enc + + def _track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def _encode_memory_in_output( + self, + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ): + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out, sam_outputs, _, _ = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/monst3r/third_party/sam2/sam2/modeling/sam2_utils.py b/monst3r/third_party/sam2/sam2/modeling/sam2_utils.py new file mode 100644 index 0000000..e16caae --- /dev/null +++ b/monst3r/third_party/sam2/sam2/modeling/sam2_utils.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.utils.misc import mask_to_box + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def sample_box_points( + masks: torch.Tensor, + noise: float = 0.1, # SAM default + noise_bound: int = 20, # SAM default + top_left_label: int = 2, + bottom_right_label: int = 3, +) -> Tuple[np.array, np.array]: + """ + Sample a noised version of the top left and bottom right corners of a given `bbox` + + Inputs: + - masks: [B, 1, H,W] boxes, dtype=torch.Tensor + - noise: noise as a fraction of box width and height, dtype=float + - noise_bound: maximum amount of noise (in pure pixesl), dtype=int + + Returns: + - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float + - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32 + """ + device = masks.device + box_coords = mask_to_box(masks) + B, _, H, W = masks.shape + box_labels = torch.tensor( + [top_left_label, bottom_right_label], dtype=torch.int, device=device + ).repeat(B) + if noise > 0.0: + if not isinstance(noise_bound, torch.Tensor): + noise_bound = torch.tensor(noise_bound, device=device) + bbox_w = box_coords[..., 2] - box_coords[..., 0] + bbox_h = box_coords[..., 3] - box_coords[..., 1] + max_dx = torch.min(bbox_w * noise, noise_bound) + max_dy = torch.min(bbox_h * noise, noise_bound) + box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1 + box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1) + + box_coords = box_coords + box_noise + img_bounds = ( + torch.tensor([W, H, W, H], device=device) - 1 + ) # uncentered pixel coords + box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping + + box_coords = box_coords.reshape(-1, 2, 2) # always 2 points + box_labels = box_labels.reshape(-1, 2) + return box_coords, box_labels + + +def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1): + """ + Sample `num_pt` random points (along with their labels) independently from the error regions. + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - num_pt: int, number of points to sample independently for each of the B error maps + + Outputs: + - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means + negative clicks + """ + if pred_masks is None: # if pred_masks is not provided, treat it as empty + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + assert num_pt >= 0 + + B, _, H_im, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + # whether the prediction completely match the ground-truth on each mask + all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2) + all_correct = all_correct[..., None, None] + + # channel 0 is FP map, while channel 1 is FN map + pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device) + # sample a negative new click from FP region or a positive new click + # from FN region, depend on where the maximum falls, + # and in case the predictions are all correct (no FP or FN), we just + # sample a negative click from the background region + pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks) + pts_noise[..., 1] *= fn_masks + pts_idx = pts_noise.flatten(2).argmax(dim=2) + labels = (pts_idx % 2).to(torch.int32) + pts_idx = pts_idx // 2 + pts_x = pts_idx % W_im + pts_y = pts_idx // W_im + points = torch.stack([pts_x, pts_y], dim=2).to(torch.float) + return points, labels + + +def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True): + """ + Sample 1 random point (along with its label) from the center of each error region, + that is, the point with the largest distance to the boundary of each error region. + This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - padding: if True, pad with boundary of 1 px for distance transform + + Outputs: + - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks + """ + import cv2 + + if pred_masks is None: + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + + B, _, _, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + + fp_masks = fp_masks.cpu().numpy() + fn_masks = fn_masks.cpu().numpy() + points = torch.zeros(B, 1, 2, dtype=torch.float) + labels = torch.ones(B, 1, dtype=torch.int32) + for b in range(B): + fn_mask = fn_masks[b, 0] + fp_mask = fp_masks[b, 0] + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") + # compute the distance of each point in FN/FP region to its boundary + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + # take the point in FN/FP region with the largest distance to its boundary + fn_mask_dt_flat = fn_mask_dt.reshape(-1) + fp_mask_dt_flat = fp_mask_dt.reshape(-1) + fn_argmax = np.argmax(fn_mask_dt_flat) + fp_argmax = np.argmax(fp_mask_dt_flat) + is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax] + pt_idx = fn_argmax if is_positive else fp_argmax + points[b, 0, 0] = pt_idx % W_im # x + points[b, 0, 1] = pt_idx // W_im # y + labels[b, 0] = int(is_positive) + + points = points.to(device) + labels = labels.to(device) + return points, labels + + +def get_next_point(gt_masks, pred_masks, method): + if method == "uniform": + return sample_random_points_from_errors(gt_masks, pred_masks) + elif method == "center": + return sample_one_point_from_error_center(gt_masks, pred_masks) + else: + raise ValueError(f"unknown sampling method {method}") diff --git a/monst3r/third_party/sam2/sam2/sam2_hiera_b+.yaml b/monst3r/third_party/sam2/sam2/sam2_hiera_b+.yaml new file mode 100644 index 0000000..58f3eb8 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/sam2_hiera_b+.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/sam2_hiera_l.yaml b/monst3r/third_party/sam2/sam2/sam2_hiera_l.yaml new file mode 100644 index 0000000..918667f --- /dev/null +++ b/monst3r/third_party/sam2/sam2/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/sam2_hiera_s.yaml b/monst3r/third_party/sam2/sam2/sam2_hiera_s.yaml new file mode 100644 index 0000000..26e5d4d --- /dev/null +++ b/monst3r/third_party/sam2/sam2/sam2_hiera_s.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/sam2_hiera_t.yaml b/monst3r/third_party/sam2/sam2/sam2_hiera_t.yaml new file mode 100644 index 0000000..a62c903 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/sam2_hiera_t.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/monst3r/third_party/sam2/sam2/sam2_image_predictor.py b/monst3r/third_party/sam2/sam2/sam2_image_predictor.py new file mode 100644 index 0000000..41ce53a --- /dev/null +++ b/monst3r/third_party/sam2/sam2/sam2_image_predictor.py @@ -0,0 +1,466 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL.Image import Image + +from sam2.modeling.sam2_base import SAM2Base + +from sam2.utils.transforms import SAM2Transforms + + +class SAM2ImagePredictor: + def __init__( + self, + sam_model: SAM2Base, + mask_threshold=0.0, + max_hole_area=0.0, + max_sprinkle_area=0.0, + **kwargs, + ) -> None: + """ + Uses SAM-2 to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam-2): The model to use for mask prediction. + mask_threshold (float): The threshold to use when converting mask logits + to binary masks. Masks are thresholded at 0 by default. + max_hole_area (int): If max_hole_area > 0, we fill small holes in up to + the maximum area of max_hole_area in low_res_masks. + max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to + the maximum area of max_sprinkle_area in low_res_masks. + """ + super().__init__() + self.model = sam_model + self._transforms = SAM2Transforms( + resolution=self.model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + + # Predictor state + self._is_image_set = False + self._features = None + self._orig_hw = None + # Whether the predictor is set for single image or a batch of images + self._is_batch = False + + # Predictor config + self.mask_threshold = mask_threshold + + # Spatial dim for backbone feature maps + self._bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2ImagePredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + + @torch.no_grad() + def set_image( + self, + image: Union[np.ndarray, Image], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image + with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + self.reset_predictor() + # Transform the image to the form expected by the model + if isinstance(image, np.ndarray): + logging.info("For numpy array image, we assume (HxWxC) format") + self._orig_hw = [image.shape[:2]] + elif isinstance(image, Image): + w, h = image.size + self._orig_hw = [(h, w)] + else: + raise NotImplementedError("Image format not supported") + + input_image = self._transforms(image) + input_image = input_image[None, ...].to(self.device) + + assert ( + len(input_image.shape) == 4 and input_image.shape[1] == 3 + ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" + logging.info("Computing image embeddings for the provided image...") + backbone_out = self.model.forward_image(input_image) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + logging.info("Image embeddings computed.") + + @torch.no_grad() + def set_image_batch( + self, + image_list: List[Union[np.ndarray]], + ) -> None: + """ + Calculates the image embeddings for the provided image batch, allowing + masks to be predicted with the 'predict_batch' method. + + Arguments: + image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray + with pixel values in [0, 255]. + """ + self.reset_predictor() + assert isinstance(image_list, list) + self._orig_hw = [] + for image in image_list: + assert isinstance( + image, np.ndarray + ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + self._orig_hw.append(image.shape[:2]) + # Transform the image to the form expected by the model + img_batch = self._transforms.forward_batch(image_list) + img_batch = img_batch.to(self.device) + batch_size = img_batch.shape[0] + assert ( + len(img_batch.shape) == 4 and img_batch.shape[1] == 3 + ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + logging.info("Computing image embeddings for the provided images...") + backbone_out = self.model.forward_image(img_batch) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + self._is_batch = True + logging.info("Image embeddings computed.") + + def predict_batch( + self, + point_coords_batch: List[np.ndarray] = None, + point_labels_batch: List[np.ndarray] = None, + box_batch: List[np.ndarray] = None, + mask_input_batch: List[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. + It returns a tuple of lists of masks, ious, and low_res_masks_logits. + """ + assert self._is_batch, "This function should only be used when in batched mode" + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image_batch(...) before mask prediction." + ) + num_images = len(self._features["image_embed"]) + all_masks = [] + all_ious = [] + all_low_res_masks = [] + for img_idx in range(num_images): + # Transform input prompts + point_coords = ( + point_coords_batch[img_idx] if point_coords_batch is not None else None + ) + point_labels = ( + point_labels_batch[img_idx] if point_labels_batch is not None else None + ) + box = box_batch[img_idx] if box_batch is not None else None + mask_input = ( + mask_input_batch[img_idx] if mask_input_batch is not None else None + ) + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, + point_labels, + box, + mask_input, + normalize_coords, + img_idx=img_idx, + ) + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + img_idx=img_idx, + ) + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = ( + iou_predictions.squeeze(0).float().detach().cpu().numpy() + ) + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + all_masks.append(masks_np) + all_ious.append(iou_predictions_np) + all_low_res_masks.append(low_res_masks_np) + + return all_masks, all_ious, all_low_res_masks + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, point_labels, box, mask_input, normalize_coords + ) + + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + def _prep_prompts( + self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1 + ): + + unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + unnorm_coords = self._transforms.transform_coords( + point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) + labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + if len(unnorm_coords.shape) == 2: + unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...] + if box is not None: + box = torch.as_tensor(box, dtype=torch.float, device=self.device) + unnorm_box = self._transforms.transform_boxes( + box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) # Bx2x2 + if mask_logits is not None: + mask_input = torch.as_tensor( + mask_logits, dtype=torch.float, device=self.device + ) + if len(mask_input.shape) == 3: + mask_input = mask_input[None, :, :, :] + return mask_input, unnorm_coords, labels, unnorm_box + + @torch.no_grad() + def _predict( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + img_idx: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using SAM2Transforms. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + concat_points = (point_coords, point_labels) + else: + concat_points = None + + # Embed prompts + if boxes is not None: + box_coords = boxes.reshape(-1, 2, 2) + box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) + box_labels = box_labels.repeat(boxes.size(0), 1) + # we merge "boxes" and "points" into a single "concat_points" input (where + # boxes are added at the beginning) to sam_prompt_encoder + if concat_points is not None: + concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) + concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) + concat_points = (concat_coords, concat_labels) + else: + concat_points = (box_coords, box_labels) + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=concat_points, + boxes=None, + masks=mask_input, + ) + + # Predict masks + batched_mode = ( + concat_points is not None and concat_points[0].shape[0] > 1 + ) # multi object prediction + high_res_features = [ + feat_level[img_idx].unsqueeze(0) + for feat_level in self._features["high_res_feats"] + ] + low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( + image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + + # Upscale the masks to the original image resolution + masks = self._transforms.postprocess_masks( + low_res_masks, self._orig_hw[img_idx] + ) + low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) + if not return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self._features is not None + ), "Features must exist if an image has been set." + return self._features["image_embed"] + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_predictor(self) -> None: + """ + Resets the image embeddings and other state variables. + """ + self._is_image_set = False + self._features = None + self._orig_hw = None + self._is_batch = False diff --git a/monst3r/third_party/sam2/sam2/sam2_video_predictor.py b/monst3r/third_party/sam2/sam2/sam2_video_predictor.py new file mode 100644 index 0000000..442968b --- /dev/null +++ b/monst3r/third_party/sam2/sam2/sam2_video_predictor.py @@ -0,0 +1,1174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from collections import OrderedDict + +import torch + +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames + + +class SAM2VideoPredictor(SAM2Base): + """The predictor class to handle user interactions and manage inference states.""" + + def __init__( + self, + fill_hole_area=0, + # whether to apply non-overlapping constraints on the output object masks + non_overlap_masks=False, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). + clear_non_cond_mem_for_multi_obj=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + **kwargs, + ): + super().__init__(**kwargs) + self.fill_hole_area = fill_hole_area + self.non_overlap_masks = non_overlap_masks + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize an inference state.""" + compute_device = self.device # device of the model + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = compute_device + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = compute_device + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2VideoPredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_video_predictor_hf + + sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) + return sam_model + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state, + frame_idx, + obj_id, + points=None, + labels=None, + clear_old_points=True, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.float32) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if inference_state["tracking_has_started"]: + warnings.warn( + "You are adding a box after tracking starts. SAM 2 may not always be " + "able to incorporate a box prompt for *refinement*. If you intend to " + "use box prompt as an *initial* input before tracking, please call " + "'reset_state' on the inference state to restart from scratch.", + category=UserWarning, + stacklevel=2, + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.float32, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + device = inference_state["device"] + prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def add_new_points(self, *args, **kwargs): + """Deprecated method. Please use `add_new_points_or_box` instead.""" + return self.add_new_points_or_box(*args, **kwargs) + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask, dtype=torch.bool) + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's image size + if mask_H != self.image_size or mask_W != self.image_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.image_size, self.image_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + mask_inputs = (mask_inputs >= 0.5).float() + else: + mask_inputs = mask_inputs_orig + + mask_inputs_per_frame[frame_idx] = mask_inputs + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["device"], + ), + "object_score_logits": torch.full( + size=(batch_size, 1), + # default to 10.0 for object_score_logits, i.e. assuming the object is + # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` + fill_value=10.0, + dtype=torch.float32, + device=inference_state["device"], + ), + } + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr( + inference_state, frame_idx + ) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[ + "object_score_logits" + ] + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_mem_enc: + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + object_score_logits=consolidated_out["object_score_logits"], + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.float32, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={}, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points_or_box` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temporary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object( + inference_state, frame_idx, consolidated_out, storage_key + ) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] + | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """Propagate the input points across frames to track in the entire video.""" + self.propagate_in_video_preflight(inference_state) + + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(output_dict["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, video_res_masks + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + "object_score_logits": current_out["object_score_logits"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def clear_all_prompts_in_frame( + self, inference_state, frame_idx, obj_id, need_output=True + ): + """Remove all input points or mask in a specific frame for a given object.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + + # Clear the conditioning information on the given frame + inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) + inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) + + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) + temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) + + # Check and see if there are still any inputs left on this frame + batch_size = self._get_obj_num(inference_state) + frame_has_input = False + for obj_idx2 in range(batch_size): + if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + + # If this frame has no remaining inputs for any objects, we further clear its + # conditioning frame status + if not frame_has_input: + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) + out = output_dict["cond_frame_outputs"].pop(frame_idx, None) + if out is not None: + # The frame is not a conditioning frame anymore since it's not receiving inputs, + # so we "downgrade" its output (if exists) to a non-conditioning frame output. + output_dict["non_cond_frame_outputs"][frame_idx] = out + inference_state["frames_already_tracked"].pop(frame_idx, None) + # Similarly, do it for the sliced output on each object. + for obj_idx2 in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] + obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) + if obj_out is not None: + obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out + + # If all the conditioning frames have been removed, we also clear the tracking outputs + if len(output_dict["cond_frame_outputs"]) == 0: + self._reset_tracking_results(inference_state) + + if not need_output: + return + # Finally, output updated masks per object (after removing the inputs above) + obj_ids = inference_state["obj_ids"] + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + device = inference_state["device"] + image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) + backbone_out = self.forward_image(image) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand( + batch_size, -1, -1, -1 + ) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + if torch.cuda.is_available(): + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores( + pred_masks_gpu, self.fill_hole_area + ) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + object_score_logits = current_out["object_score_logits"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + "object_score_logits": object_score_logits, + } + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder( + self, + inference_state, + frame_idx, + batch_size, + high_res_masks, + object_score_logits, + is_mask_from_pts, + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + if torch.cuda.is_available(): + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + @torch.inference_mode() + def remove_object(self, inference_state, obj_id, strict=False, need_output=True): + """ + Remove an object id from the tracking state. If strict is True, we check whether + the object id actually exists and raise an error if it doesn't exist. + """ + old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None) + updated_frames = [] + # Check whether this object_id to remove actually exists and possibly raise an error. + if old_obj_idx_to_rm is None: + if not strict: + return inference_state["obj_ids"], updated_frames + raise RuntimeError( + f"Cannot remove object id {obj_id} as it doesn't exist. " + f"All existing object ids: {inference_state['obj_ids']}." + ) + + # If this is the only remaining object id, we simply reset the state. + if len(inference_state["obj_id_to_idx"]) == 1: + self.reset_state(inference_state) + return inference_state["obj_ids"], updated_frames + + # There are still remaining objects after removing this object id. In this case, + # we need to delete the object storage from inference state tensors. + # Step 0: clear the input on those frames where this object id has point or mask input + # (note that this step is required as it might downgrade conditioning frames to + # non-conditioning ones) + obj_input_frames_inds = set() + obj_input_frames_inds.update( + inference_state["point_inputs_per_obj"][old_obj_idx_to_rm] + ) + obj_input_frames_inds.update( + inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm] + ) + for frame_idx in obj_input_frames_inds: + self.clear_all_prompts_in_frame( + inference_state, frame_idx, obj_id, need_output=False + ) + + # Step 1: Update the object id mapping (note that it must be done after Step 0, + # since Step 0 still requires the old object id mappings in inference_state) + old_obj_ids = inference_state["obj_ids"] + old_obj_inds = list(range(len(old_obj_ids))) + remain_old_obj_inds = old_obj_inds.copy() + remain_old_obj_inds.remove(old_obj_idx_to_rm) + new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] + new_obj_inds = list(range(len(new_obj_ids))) + # build new mappings + old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) + inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) + inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) + inference_state["obj_ids"] = new_obj_ids + + # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. + # (note that "consolidated_frame_inds" doesn't need to be updated in this step as + # it's already handled in Step 0) + def _map_keys(container): + new_kvs = [] + for k in old_obj_inds: + v = container.pop(k) + if k in old_idx_to_new_idx: + new_kvs.append((old_idx_to_new_idx[k], v)) + container.update(new_kvs) + + _map_keys(inference_state["point_inputs_per_obj"]) + _map_keys(inference_state["mask_inputs_per_obj"]) + _map_keys(inference_state["output_dict_per_obj"]) + _map_keys(inference_state["temp_output_dict_per_obj"]) + + # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices. + def _slice_state(output_dict, storage_key): + for frame_idx, out in output_dict[storage_key].items(): + out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds] + out["maskmem_pos_enc"] = [ + x[remain_old_obj_inds] for x in out["maskmem_pos_enc"] + ] + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out) + out["pred_masks"] = out["pred_masks"][remain_old_obj_inds] + out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds] + out["object_score_logits"] = out["object_score_logits"][ + remain_old_obj_inds + ] + # also update the per-object slices + self._add_output_per_object( + inference_state, frame_idx, out, storage_key + ) + + _slice_state(inference_state["output_dict"], "cond_frame_outputs") + _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") + + # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which + # could show an updated mask for objects previously occluded by the object being removed + if need_output: + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + for frame_idx in obj_input_frames_inds: + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + updated_frames.append((frame_idx, video_res_masks)) + + return inference_state["obj_ids"], updated_frames + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + output_dict = inference_state["output_dict"] + non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/monst3r/third_party/sam2/sam2/utils/__init__.py b/monst3r/third_party/sam2/sam2/utils/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/monst3r/third_party/sam2/sam2/utils/amg.py b/monst3r/third_party/sam2/sam2/utils/amg.py new file mode 100644 index 0000000..9868429 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/utils/amg.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + +import numpy as np +import torch + +# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.float().detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/monst3r/third_party/sam2/sam2/utils/misc.py b/monst3r/third_party/sam2/sam2/utils/misc.py new file mode 100644 index 0000000..136ebd6 --- /dev/null +++ b/monst3r/third_party/sam2/sam2/utils/misc.py @@ -0,0 +1,380 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + # elif is a list + elif isinstance(video_path, list): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + list_input=True, + ) + elif isinstance(video_path, torch.Tensor): + N, D, H, W = video_path.shape + # use bicubic interpolation to resize the video frames + video_frames = torch.nn.functional.interpolate( + video_path, size=(image_size, image_size), mode="bicubic", align_corners=False + ) + # normalize by mean and std + if not offload_video_to_cpu: + video_frames = video_frames.to(compute_device) + img_mean = torch.tensor(img_mean, dtype=torch.float32).to(compute_device)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32).to(compute_device)[:, None, None] + video_frames -= img_mean + video_frames /= img_std + return video_frames, H, W + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), + list_input=False, +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if list_input: + img_paths = video_path + num_frames = len(img_paths) + else: + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/monst3r/third_party/sam2/sam2/utils/transforms.py b/monst3r/third_party/sam2/sam2/utils/transforms.py new file mode 100644 index 0000000..cc17beb --- /dev/null +++ b/monst3r/third_party/sam2/sam2/utils/transforms.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + from sam2.utils.misc import get_connected_components + + masks = masks.float() + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = get_connected_components( + mask_flat <= self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components( + mask_flat > self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/monst3r/third_party/sam2/setup.py b/monst3r/third_party/sam2/setup.py new file mode 100644 index 0000000..c67a949 --- /dev/null +++ b/monst3r/third_party/sam2/setup.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import os + +from setuptools import find_packages, setup + +# Package metadata +NAME = "SAM-2" +VERSION = "1.0" +DESCRIPTION = "SAM 2: Segment Anything in Images and Videos" +URL = "https://github.com/facebookresearch/sam2" +AUTHOR = "Meta AI" +AUTHOR_EMAIL = "segment-anything@meta.com" +LICENSE = "Apache 2.0" + +# Read the contents of README file +with open("README.md", "r", encoding="utf-8") as f: + LONG_DESCRIPTION = f.read() + +# Required dependencies +REQUIRED_PACKAGES = [ + "torch>=2.3.1", + "torchvision>=0.18.1", + "numpy>=1.24.4", + "tqdm>=4.66.1", + "hydra-core>=1.3.2", + "iopath>=0.1.10", + "pillow>=9.4.0", +] + +EXTRA_PACKAGES = { + "notebooks": [ + "matplotlib>=3.9.1", + "jupyter>=1.0.0", + "opencv-python>=4.7.0", + "eva-decord>=0.6.1", + ], + "interactive-demo": [ + "Flask>=3.0.3", + "Flask-Cors>=5.0.0", + "av>=13.0.0", + "dataclasses-json>=0.6.7", + "eva-decord>=0.6.1", + "gunicorn>=23.0.0", + "imagesize>=1.4.1", + "pycocotools>=2.0.8", + "strawberry-graphql>=0.243.0", + ], + "dev": [ + "black==24.2.0", + "usort==1.0.2", + "ufmt==2.0.0b2", + "fvcore>=0.1.5.post20221221", + "pandas>=2.2.2", + "scikit-image>=0.24.0", + "tensorboard>=2.17.0", + "pycocotools>=2.0.8", + "tensordict>=0.5.0", + "opencv-python>=4.7.0", + "submitit>=1.5.1", + ], +} + +# By default, we also build the SAM 2 CUDA extension. +# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`. +BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1" +# By default, we allow SAM 2 installation to proceed even with build errors. +# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`. +BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1" + +# Catch and skip errors during extension building and print a warning message +# (note that this message only shows up under verbose build mode +# "pip install -v -e ." or "python setup.py build_ext -v") +CUDA_ERROR_MSG = ( + "{}\n\n" + "Failed to build the SAM 2 CUDA extension due to the error above. " + "You can still use SAM 2 and it's OK to ignore the error above, although some " + "post-processing functionality may be limited (which doesn't affect the results in most cases; " + "(see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n" +) + + +def get_extensions(): + if not BUILD_CUDA: + return [] + + try: + from torch.utils.cpp_extension import CUDAExtension + + srcs = ["sam2/csrc/connected_components.cu"] + compile_args = { + "cxx": [], + "nvcc": [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ], + } + ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] + except Exception as e: + if BUILD_ALLOW_ERRORS: + print(CUDA_ERROR_MSG.format(e)) + ext_modules = [] + else: + raise e + + return ext_modules + + +try: + from torch.utils.cpp_extension import BuildExtension + + class BuildExtensionIgnoreErrors(BuildExtension): + + def finalize_options(self): + try: + super().finalize_options() + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] + + def build_extensions(self): + try: + super().build_extensions() + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] + + def get_ext_filename(self, ext_name): + try: + return super().get_ext_filename(ext_name) + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] + return "_C.so" + + cmdclass = { + "build_ext": ( + BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True) + if BUILD_ALLOW_ERRORS + else BuildExtension.with_options(no_python_abi_suffix=True) + ) + } +except Exception as e: + cmdclass = {} + if BUILD_ALLOW_ERRORS: + print(CUDA_ERROR_MSG.format(e)) + else: + raise e + + +# Setup configuration +setup( + name=NAME, + version=VERSION, + description=DESCRIPTION, + long_description=LONG_DESCRIPTION, + long_description_content_type="text/markdown", + url=URL, + author=AUTHOR, + author_email=AUTHOR_EMAIL, + license=LICENSE, + packages=find_packages(exclude="notebooks"), + include_package_data=True, + install_requires=REQUIRED_PACKAGES, + extras_require=EXTRA_PACKAGES, + python_requires=">=3.10.0", + ext_modules=get_extensions(), + cmdclass=cmdclass, +) diff --git a/monst3r/third_party/sam2/tools/README.md b/monst3r/third_party/sam2/tools/README.md new file mode 100644 index 0000000..1dd0e8a --- /dev/null +++ b/monst3r/third_party/sam2/tools/README.md @@ -0,0 +1,36 @@ +## SAM 2 toolkits + +This directory provides toolkits for additional SAM 2 use cases. + +### Semi-supervised VOS inference + +The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [MOSE](https://henghuiding.github.io/MOSE/) or the SA-V dataset. + +After installing SAM 2 and its dependencies, it can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the `--output_mask_dir`. +```bash +python ./tools/vos_inference.py \ + --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ + --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ + --base_video_dir /path-to-davis-2017/JPEGImages/480p \ + --input_mask_dir /path-to-davis-2017/Annotations/480p \ + --video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \ + --output_mask_dir ./outputs/davis_2017_pred_pngs +``` +(replace `/path-to-davis-2017` with the path to DAVIS 2017 dataset) + +To evaluate on the SA-V dataset with per-object PNG files for the object masks, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). This script will also save per-object PNG files for the output masks under the `--per_obj_png_file` flag. +```bash +python ./tools/vos_inference.py \ + --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ + --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ + --base_video_dir /path-to-sav-val/JPEGImages_24fps \ + --input_mask_dir /path-to-sav-val/Annotations_6fps \ + --video_list_file /path-to-sav-val/sav_val.txt \ + --per_obj_png_file \ + --output_mask_dir ./outputs/sav_val_pred_pngs +``` +(replace `/path-to-sav-val` with the path to SA-V val) + +Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above. + +Note: by default, the `vos_inference.py` script above assumes that all objects to track already appear on frame 0 in each video (as is the case in DAVIS, MOSE or SA-V). **For VOS datasets that don't have all objects to track appearing in the first frame (such as LVOS or YouTube-VOS), please add the `--track_object_appearing_later_in_video` flag when using `vos_inference.py`**. diff --git a/monst3r/third_party/sam2/tools/vos_inference.py b/monst3r/third_party/sam2/tools/vos_inference.py new file mode 100644 index 0000000..5c40cda --- /dev/null +++ b/monst3r/third_party/sam2/tools/vos_inference.py @@ -0,0 +1,501 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +from collections import defaultdict + +import numpy as np +import torch +from PIL import Image +from sam2.build_sam import build_sam2_video_predictor + + +# the PNG palette for DAVIS 2017 dataset +DAVIS_PALETTE = b"\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0" + + +def load_ann_png(path): + """Load a PNG file as a mask and its palette.""" + mask = Image.open(path) + palette = mask.getpalette() + mask = np.array(mask).astype(np.uint8) + return mask, palette + + +def save_ann_png(path, mask, palette): + """Save a mask as a PNG file with the given palette.""" + assert mask.dtype == np.uint8 + assert mask.ndim == 2 + output_mask = Image.fromarray(mask) + output_mask.putpalette(palette) + output_mask.save(path) + + +def get_per_obj_mask(mask): + """Split a mask into per-object masks.""" + object_ids = np.unique(mask) + object_ids = object_ids[object_ids > 0].tolist() + per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids} + return per_obj_mask + + +def put_per_obj_mask(per_obj_mask, height, width): + """Combine per-object masks into a single mask.""" + mask = np.zeros((height, width), dtype=np.uint8) + object_ids = sorted(per_obj_mask)[::-1] + for object_id in object_ids: + object_mask = per_obj_mask[object_id] + object_mask = object_mask.reshape(height, width) + mask[object_mask] = object_id + return mask + + +def load_masks_from_dir( + input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False +): + """Load masks from a directory as a dict of per-object masks.""" + if not per_obj_png_file: + input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png") + if allow_missing and not os.path.exists(input_mask_path): + return {}, None + input_mask, input_palette = load_ann_png(input_mask_path) + per_obj_input_mask = get_per_obj_mask(input_mask) + else: + per_obj_input_mask = {} + input_palette = None + # each object is a directory in "{object_id:%03d}" format + for object_name in os.listdir(os.path.join(input_mask_dir, video_name)): + object_id = int(object_name) + input_mask_path = os.path.join( + input_mask_dir, video_name, object_name, f"{frame_name}.png" + ) + if allow_missing and not os.path.exists(input_mask_path): + continue + input_mask, input_palette = load_ann_png(input_mask_path) + per_obj_input_mask[object_id] = input_mask > 0 + + return per_obj_input_mask, input_palette + + +def save_masks_to_dir( + output_mask_dir, + video_name, + frame_name, + per_obj_output_mask, + height, + width, + per_obj_png_file, + output_palette, +): + """Save masks to a directory as PNG files.""" + os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True) + if not per_obj_png_file: + output_mask = put_per_obj_mask(per_obj_output_mask, height, width) + output_mask_path = os.path.join( + output_mask_dir, video_name, f"{frame_name}.png" + ) + save_ann_png(output_mask_path, output_mask, output_palette) + else: + for object_id, object_mask in per_obj_output_mask.items(): + object_name = f"{object_id:03d}" + os.makedirs( + os.path.join(output_mask_dir, video_name, object_name), + exist_ok=True, + ) + output_mask = object_mask.reshape(height, width).astype(np.uint8) + output_mask_path = os.path.join( + output_mask_dir, video_name, object_name, f"{frame_name}.png" + ) + save_ann_png(output_mask_path, output_mask, output_palette) + + +@torch.inference_mode() +@torch.autocast(device_type="cuda", dtype=torch.bfloat16) +def vos_inference( + predictor, + base_video_dir, + input_mask_dir, + output_mask_dir, + video_name, + score_thresh=0.0, + use_all_masks=False, + per_obj_png_file=False, +): + """Run VOS inference on a single video with the given predictor.""" + # load the video frames and initialize the inference state on this video + video_dir = os.path.join(base_video_dir, video_name) + frame_names = [ + os.path.splitext(p)[0] + for p in os.listdir(video_dir) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + inference_state = predictor.init_state( + video_path=video_dir, async_loading_frames=False + ) + height = inference_state["video_height"] + width = inference_state["video_width"] + input_palette = None + + # fetch mask inputs from input_mask_dir (either only mask for the first frame, or all available masks) + if not use_all_masks: + # use only the first video's ground-truth mask as the input mask + input_frame_inds = [0] + else: + # use all mask files available in the input_mask_dir as the input masks + if not per_obj_png_file: + input_frame_inds = [ + idx + for idx, name in enumerate(frame_names) + if os.path.exists( + os.path.join(input_mask_dir, video_name, f"{name}.png") + ) + ] + else: + input_frame_inds = [ + idx + for object_name in os.listdir(os.path.join(input_mask_dir, video_name)) + for idx, name in enumerate(frame_names) + if os.path.exists( + os.path.join(input_mask_dir, video_name, object_name, f"{name}.png") + ) + ] + # check and make sure we got at least one input frame + if len(input_frame_inds) == 0: + raise RuntimeError( + f"In {video_name=}, got no input masks in {input_mask_dir=}. " + "Please make sure the input masks are available in the correct format." + ) + input_frame_inds = sorted(set(input_frame_inds)) + + # add those input masks to SAM 2 inference state before propagation + object_ids_set = None + for input_frame_idx in input_frame_inds: + try: + per_obj_input_mask, input_palette = load_masks_from_dir( + input_mask_dir=input_mask_dir, + video_name=video_name, + frame_name=frame_names[input_frame_idx], + per_obj_png_file=per_obj_png_file, + ) + except FileNotFoundError as e: + raise RuntimeError( + f"In {video_name=}, failed to load input mask for frame {input_frame_idx=}. " + "Please add the `--track_object_appearing_later_in_video` flag " + "for VOS datasets that don't have all objects to track appearing " + "in the first frame (such as LVOS or YouTube-VOS)." + ) from e + # get the list of object ids to track from the first input frame + if object_ids_set is None: + object_ids_set = set(per_obj_input_mask) + for object_id, object_mask in per_obj_input_mask.items(): + # check and make sure no new object ids appear only in later frames + if object_id not in object_ids_set: + raise RuntimeError( + f"In {video_name=}, got a new {object_id=} appearing only in a " + f"later {input_frame_idx=} (but not appearing in the first frame). " + "Please add the `--track_object_appearing_later_in_video` flag " + "for VOS datasets that don't have all objects to track appearing " + "in the first frame (such as LVOS or YouTube-VOS)." + ) + predictor.add_new_mask( + inference_state=inference_state, + frame_idx=input_frame_idx, + obj_id=object_id, + mask=object_mask, + ) + + # check and make sure we have at least one object to track + if object_ids_set is None or len(object_ids_set) == 0: + raise RuntimeError( + f"In {video_name=}, got no object ids on {input_frame_inds=}. " + "Please add the `--track_object_appearing_later_in_video` flag " + "for VOS datasets that don't have all objects to track appearing " + "in the first frame (such as LVOS or YouTube-VOS)." + ) + # run propagation throughout the video and collect the results in a dict + os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True) + output_palette = input_palette or DAVIS_PALETTE + video_segments = {} # video_segments contains the per-frame segmentation results + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video( + inference_state + ): + per_obj_output_mask = { + out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy() + for i, out_obj_id in enumerate(out_obj_ids) + } + video_segments[out_frame_idx] = per_obj_output_mask + + # write the output masks as palette PNG files to output_mask_dir + for out_frame_idx, per_obj_output_mask in video_segments.items(): + save_masks_to_dir( + output_mask_dir=output_mask_dir, + video_name=video_name, + frame_name=frame_names[out_frame_idx], + per_obj_output_mask=per_obj_output_mask, + height=height, + width=width, + per_obj_png_file=per_obj_png_file, + output_palette=output_palette, + ) + + +@torch.inference_mode() +@torch.autocast(device_type="cuda", dtype=torch.bfloat16) +def vos_separate_inference_per_object( + predictor, + base_video_dir, + input_mask_dir, + output_mask_dir, + video_name, + score_thresh=0.0, + use_all_masks=False, + per_obj_png_file=False, +): + """ + Run VOS inference on a single video with the given predictor. + + Unlike `vos_inference`, this function run inference separately for each object + in a video, which could be applied to datasets like LVOS or YouTube-VOS that + don't have all objects to track appearing in the first frame (i.e. some objects + might appear only later in the video). + """ + # load the video frames and initialize the inference state on this video + video_dir = os.path.join(base_video_dir, video_name) + frame_names = [ + os.path.splitext(p)[0] + for p in os.listdir(video_dir) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + inference_state = predictor.init_state( + video_path=video_dir, async_loading_frames=False + ) + height = inference_state["video_height"] + width = inference_state["video_width"] + input_palette = None + + # collect all the object ids and their input masks + inputs_per_object = defaultdict(dict) + for idx, name in enumerate(frame_names): + if per_obj_png_file or os.path.exists( + os.path.join(input_mask_dir, video_name, f"{name}.png") + ): + per_obj_input_mask, input_palette = load_masks_from_dir( + input_mask_dir=input_mask_dir, + video_name=video_name, + frame_name=frame_names[idx], + per_obj_png_file=per_obj_png_file, + allow_missing=True, + ) + for object_id, object_mask in per_obj_input_mask.items(): + # skip empty masks + if not np.any(object_mask): + continue + # if `use_all_masks=False`, we only use the first mask for each object + if len(inputs_per_object[object_id]) > 0 and not use_all_masks: + continue + print(f"adding mask from frame {idx} as input for {object_id=}") + inputs_per_object[object_id][idx] = object_mask + + # run inference separately for each object in the video + object_ids = sorted(inputs_per_object) + output_scores_per_object = defaultdict(dict) + for object_id in object_ids: + # add those input masks to SAM 2 inference state before propagation + input_frame_inds = sorted(inputs_per_object[object_id]) + predictor.reset_state(inference_state) + for input_frame_idx in input_frame_inds: + predictor.add_new_mask( + inference_state=inference_state, + frame_idx=input_frame_idx, + obj_id=object_id, + mask=inputs_per_object[object_id][input_frame_idx], + ) + + # run propagation throughout the video and collect the results in a dict + for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video( + inference_state, + start_frame_idx=min(input_frame_inds), + reverse=False, + ): + obj_scores = out_mask_logits.cpu().numpy() + output_scores_per_object[object_id][out_frame_idx] = obj_scores + + # post-processing: consolidate the per-object scores into per-frame masks + os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True) + output_palette = input_palette or DAVIS_PALETTE + video_segments = {} # video_segments contains the per-frame segmentation results + for frame_idx in range(len(frame_names)): + scores = torch.full( + size=(len(object_ids), 1, height, width), + fill_value=-1024.0, + dtype=torch.float32, + ) + for i, object_id in enumerate(object_ids): + if frame_idx in output_scores_per_object[object_id]: + scores[i] = torch.from_numpy( + output_scores_per_object[object_id][frame_idx] + ) + + if not per_obj_png_file: + scores = predictor._apply_non_overlapping_constraints(scores) + per_obj_output_mask = { + object_id: (scores[i] > score_thresh).cpu().numpy() + for i, object_id in enumerate(object_ids) + } + video_segments[frame_idx] = per_obj_output_mask + + # write the output masks as palette PNG files to output_mask_dir + for frame_idx, per_obj_output_mask in video_segments.items(): + save_masks_to_dir( + output_mask_dir=output_mask_dir, + video_name=video_name, + frame_name=frame_names[frame_idx], + per_obj_output_mask=per_obj_output_mask, + height=height, + width=width, + per_obj_png_file=per_obj_png_file, + output_palette=output_palette, + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--sam2_cfg", + type=str, + default="configs/sam2.1/sam2.1_hiera_b+.yaml", + help="SAM 2 model configuration file", + ) + parser.add_argument( + "--sam2_checkpoint", + type=str, + default="./checkpoints/sam2.1_hiera_b+.pt", + help="path to the SAM 2 model checkpoint", + ) + parser.add_argument( + "--base_video_dir", + type=str, + required=True, + help="directory containing videos (as JPEG files) to run VOS prediction on", + ) + parser.add_argument( + "--input_mask_dir", + type=str, + required=True, + help="directory containing input masks (as PNG files) of each video", + ) + parser.add_argument( + "--video_list_file", + type=str, + default=None, + help="text file containing the list of video names to run VOS prediction on", + ) + parser.add_argument( + "--output_mask_dir", + type=str, + required=True, + help="directory to save the output masks (as PNG files)", + ) + parser.add_argument( + "--score_thresh", + type=float, + default=0.0, + help="threshold for the output mask logits (default: 0.0)", + ) + parser.add_argument( + "--use_all_masks", + action="store_true", + help="whether to use all available PNG files in input_mask_dir " + "(default without this flag: just the first PNG file as input to the SAM 2 model; " + "usually we don't need this flag, since semi-supervised VOS evaluation usually takes input from the first frame only)", + ) + parser.add_argument( + "--per_obj_png_file", + action="store_true", + help="whether use separate per-object PNG files for input and output masks " + "(default without this flag: all object masks are packed into a single PNG file on each frame following DAVIS format; " + "note that the SA-V dataset stores each object mask as an individual PNG file and requires this flag)", + ) + parser.add_argument( + "--apply_postprocessing", + action="store_true", + help="whether to apply postprocessing (e.g. hole-filling) to the output masks " + "(we don't apply such post-processing in the SAM 2 model evaluation)", + ) + parser.add_argument( + "--track_object_appearing_later_in_video", + action="store_true", + help="whether to track objects that appear later in the video (i.e. not on the first frame; " + "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)", + ) + args = parser.parse_args() + + # if we use per-object PNG files, they could possibly overlap in inputs and outputs + hydra_overrides_extra = [ + "++model.non_overlap_masks=" + ("false" if args.per_obj_png_file else "true") + ] + predictor = build_sam2_video_predictor( + config_file=args.sam2_cfg, + ckpt_path=args.sam2_checkpoint, + apply_postprocessing=args.apply_postprocessing, + hydra_overrides_extra=hydra_overrides_extra, + ) + + if args.use_all_masks: + print("using all available masks in input_mask_dir as input to the SAM 2 model") + else: + print( + "using only the first frame's mask in input_mask_dir as input to the SAM 2 model" + ) + # if a video list file is provided, read the video names from the file + # (otherwise, we use all subdirectories in base_video_dir) + if args.video_list_file is not None: + with open(args.video_list_file, "r") as f: + video_names = [v.strip() for v in f.readlines()] + else: + video_names = [ + p + for p in os.listdir(args.base_video_dir) + if os.path.isdir(os.path.join(args.base_video_dir, p)) + ] + print(f"running VOS prediction on {len(video_names)} videos:\n{video_names}") + + for n_video, video_name in enumerate(video_names): + print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}") + if not args.track_object_appearing_later_in_video: + vos_inference( + predictor=predictor, + base_video_dir=args.base_video_dir, + input_mask_dir=args.input_mask_dir, + output_mask_dir=args.output_mask_dir, + video_name=video_name, + score_thresh=args.score_thresh, + use_all_masks=args.use_all_masks, + per_obj_png_file=args.per_obj_png_file, + ) + else: + vos_separate_inference_per_object( + predictor=predictor, + base_video_dir=args.base_video_dir, + input_mask_dir=args.input_mask_dir, + output_mask_dir=args.output_mask_dir, + video_name=video_name, + score_thresh=args.score_thresh, + use_all_masks=args.use_all_masks, + per_obj_png_file=args.per_obj_png_file, + ) + + print( + f"completed VOS prediction on {len(video_names)} videos -- " + f"output masks saved to {args.output_mask_dir}" + ) + + +if __name__ == "__main__": + main() diff --git a/monst3r/third_party/sam2/training/README.md b/monst3r/third_party/sam2/training/README.md new file mode 100644 index 0000000..b0c829d --- /dev/null +++ b/monst3r/third_party/sam2/training/README.md @@ -0,0 +1,116 @@ +# Training Code for SAM 2 + +This folder contains the training code for SAM 2, a foundation model for promptable visual segmentation in images and videos. +The code allows users to train and fine-tune SAM 2 on their own datasets (image, video, or both). + +## Structure + +The training code is organized into the following subfolders: + +* `dataset`: This folder contains image and video dataset and dataloader classes as well as their transforms. +* `model`: This folder contains the main model class (`SAM2Train`) for training/fine-tuning. `SAM2Train` inherits from `SAM2Base` model and provides functions to enable training or fine-tuning SAM 2. It also accepts all training-time parameters used for simulating user prompts (e.g. iterative point sampling). +* `utils`: This folder contains training utils such as loggers and distributed training utils. +* `scripts`: This folder contains the script to extract the frames of SA-V dataset to be used in training. +* `loss_fns.py`: This file has the main loss class (`MultiStepMultiMasksAndIous`) used for training. +* `optimizer.py`: This file contains all optimizer utils that support arbitrary schedulers. +* `trainer.py`: This file contains the `Trainer` class that accepts all the `Hydra` configurable modules (model, optimizer, datasets, etc..) and implements the main train/eval loop. +* `train.py`: This script is used to launch training jobs. It supports single and multi-node jobs. For usage, please check the [Getting Started](README.md#getting-started) section or run `python training/train.py -h` + +## Getting Started + +To get started with the training code, we provide a simple example to fine-tune our checkpoints on [MOSE](https://henghuiding.github.io/MOSE/) dataset, which can be extended to your custom datasets. + +#### Requirements: +- We assume training on A100 GPUs with **80 GB** of memory. +- Download the MOSE dataset using one of the provided links from [here](https://github.com/henghuiding/MOSE-api?tab=readme-ov-file#download). + +#### Steps to fine-tune on MOSE: +- Install the packages required for training by running `pip install -e ".[dev]"`. +- Set the paths for MOSE dataset in `configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml`. + ```yaml + dataset: + # PATHS to Dataset + img_folder: null # PATH to MOSE JPEGImages folder + gt_folder: null # PATH to MOSE Annotations folder + file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training + ``` +- To fine-tune the base model on MOSE using 8 GPUs, run + + ```python + python training/train.py \ + -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \ + --use-cluster 0 \ + --num-gpus 8 + ``` + + We also support multi-node training on a cluster using [SLURM](https://slurm.schedmd.com/documentation.html), for example, you can train on 2 nodes by running + + ```python + python training/train.py \ + -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \ + --use-cluster 1 \ + --num-gpus 8 \ + --num-nodes 2 + --partition $PARTITION \ + --qos $QOS \ + --account $ACCOUNT + ``` + where partition, qos, and account are optional and depend on your SLURM configuration. + By default, the checkpoint and logs will be saved under `sam2_logs` directory in the root of the repo. Alternatively, you can set the experiment log directory in the config file as follows: + + ```yaml + experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name} + ``` + The training losses can be monitored using `tensorboard` logs stored under `tensorboard/` in the experiment log directory. We also provide a sample validation [split]( ../training/assets/MOSE_sample_val_list.txt) for evaluation purposes. To generate predictions, follow this [guide](../tools/README.md) on how to use our `vos_inference.py` script. After generating the predictions, you can run the `sav_evaluator.py` as detailed [here](../sav_dataset/README.md#sa-v-val-and-test-evaluation). The expected MOSE J&F after fine-tuning the Base plus model is 79.4. + + + After training/fine-tuning, you can then use the new checkpoint (saved in `checkpoints/` in the experiment log directory) similar to SAM 2 released checkpoints (as illustrated [here](../README.md#image-prediction)). +## Training on images and videos +The code supports training on images and videos (similar to how SAM 2 is trained). We provide classes for loading SA-1B as a sample image dataset, SA-V as a sample video dataset, as well as any DAVIS-style video dataset (e.g. MOSE). Note that to train on SA-V, you must first extract all videos to JPEG frames using the provided extraction [script](./scripts/sav_frame_extraction_submitit.py). Below is an example of how to setup the datasets in your config to train on a mix of image and video datasets: + +```yaml +data: + train: + _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset + phases_per_epoch: ${phases_per_epoch} # Chunks a single epoch into smaller phases + batch_sizes: # List of batch sizes corresponding to each dataset + - ${bs1} # Batch size of dataset 1 + - ${bs2} # Batch size of dataset 2 + datasets: + # SA1B as an example of an image dataset + - _target_: training.dataset.vos_dataset.VOSDataset + training: true + video_dataset: + _target_: training.dataset.vos_raw_dataset.SA1BRawDataset + img_folder: ${path_to_img_folder} + gt_folder: ${path_to_gt_folder} + file_list_txt: ${path_to_train_filelist} # Optional + sampler: + _target_: training.dataset.vos_sampler.RandomUniformSampler + num_frames: 1 + max_num_objects: ${max_num_objects_per_image} + transforms: ${image_transforms} + # SA-V as an example of a video dataset + - _target_: training.dataset.vos_dataset.VOSDataset + training: true + video_dataset: + _target_: training.dataset.vos_raw_dataset.JSONRawDataset + img_folder: ${path_to_img_folder} + gt_folder: ${path_to_gt_folder} + file_list_txt: ${path_to_train_filelist} # Optional + ann_every: 4 + sampler: + _target_: training.dataset.vos_sampler.RandomUniformSampler + num_frames: 8 # Number of frames per video + max_num_objects: ${max_num_objects_per_video} + reverse_time_prob: ${reverse_time_prob} # probability to reverse video + transforms: ${video_transforms} + shuffle: True + num_workers: ${num_train_workers} + pin_memory: True + drop_last: True + collate_fn: + _target_: training.utils.data_utils.collate_fn + _partial_: true + dict_key: all +``` diff --git a/monst3r/third_party/sam2/training/__init__.py b/monst3r/third_party/sam2/training/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/monst3r/third_party/sam2/training/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/monst3r/third_party/sam2/training/dataset/__init__.py b/monst3r/third_party/sam2/training/dataset/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/monst3r/third_party/sam2/training/dataset/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/monst3r/third_party/sam2/training/dataset/sam2_datasets.py b/monst3r/third_party/sam2/training/dataset/sam2_datasets.py new file mode 100644 index 0000000..6deda05 --- /dev/null +++ b/monst3r/third_party/sam2/training/dataset/sam2_datasets.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from typing import Callable, Iterable, List, Optional, Sequence + +import torch + +from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset + +from torch.utils.data.distributed import DistributedSampler + + +class MixedDataLoader: + def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor): + """ + Args: + dataloaders (List[DataLoader]): List of DataLoaders to be mixed. + mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from + + """ + assert len(dataloaders) == mixing_prob.shape[0] + self.dataloaders = dataloaders + self.mixing_prob = mixing_prob + # Iterator state + self._iter_dls = None + self._iter_mixing_prob = None + self.random_generator = torch.Generator() + + def __len__(self): + return sum([len(d) for d in self.dataloaders]) + + def __iter__(self): + # Synchronize dataloader seeds + self.random_generator.manual_seed(42) + self._iter_dls = [iter(loader) for loader in self.dataloaders] + self._iter_mixing_prob = self.mixing_prob.clone() + return self + + def __next__(self): + """ + Sample a dataloader to sample from based on mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the other loaders until all are exhausted. + """ + if self._iter_dls is None: + raise TypeError(f"{type(self).__name__} object is not an iterator") + + while self._iter_mixing_prob.any(): # at least one D-Loader with non-zero prob. + dataset_idx = self._iter_mixing_prob.multinomial( + 1, generator=self.random_generator + ).item() + try: + item = next(self._iter_dls[dataset_idx]) + return item + except StopIteration: + # No more iterations for this dataset, set it's mixing probability to zero and try again. + self._iter_mixing_prob[dataset_idx] = 0 + except Exception as e: + # log and raise any other unexpected error. + logging.error(e) + raise e + + # Exhausted all iterators + raise StopIteration + + +class TorchTrainMixedDataset: + def __init__( + self, + datasets: List[Dataset], + batch_sizes: List[int], + num_workers: int, + shuffle: bool, + pin_memory: bool, + drop_last: bool, + collate_fn: Optional[Callable] = None, + worker_init_fn: Optional[Callable] = None, + phases_per_epoch: int = 1, + dataset_prob: Optional[List[float]] = None, + ) -> None: + """ + Args: + datasets (List[Dataset]): List of Datasets to be mixed. + batch_sizes (List[int]): Batch sizes for each dataset in the list. + num_workers (int): Number of workers per dataloader. + shuffle (bool): Whether or not to shuffle data. + pin_memory (bool): If True, use pinned memory when loading tensors from disk. + drop_last (bool): Whether or not to drop the last batch of data. + collate_fn (Callable): Function to merge a list of samples into a mini-batch. + worker_init_fn (Callable): Function to init each dataloader worker. + phases_per_epoch (int): Number of phases per epoch. + dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0 + """ + + self.datasets = datasets + self.batch_sizes = batch_sizes + self.num_workers = num_workers + self.shuffle = shuffle + self.pin_memory = pin_memory + self.drop_last = drop_last + self.collate_fn = collate_fn + self.worker_init_fn = worker_init_fn + assert len(self.datasets) > 0 + for dataset in self.datasets: + assert not isinstance(dataset, IterableDataset), "Not supported" + # `RepeatFactorWrapper` requires calling set_epoch first to get its length + self._set_dataset_epoch(dataset, 0) + self.phases_per_epoch = phases_per_epoch + self.chunks = [None] * len(datasets) + if dataset_prob is None: + # If not provided, assign each dataset a probability proportional to its length. + dataset_lens = [ + (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs)) + for d, bs in zip(datasets, batch_sizes) + ] + total_len = sum(dataset_lens) + dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens]) + else: + assert len(dataset_prob) == len(datasets) + dataset_prob = torch.tensor(dataset_prob) + + logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}") + assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0" + self.dataset_prob = dataset_prob + + def _set_dataset_epoch(self, dataset, epoch: int) -> None: + if hasattr(dataset, "epoch"): + dataset.epoch = epoch + if hasattr(dataset, "set_epoch"): + dataset.set_epoch(epoch) + + def get_loader(self, epoch) -> Iterable: + dataloaders = [] + for d_idx, (dataset, batch_size) in enumerate( + zip(self.datasets, self.batch_sizes) + ): + if self.phases_per_epoch > 1: + # Major epoch that looops over entire dataset + # len(main_epoch) == phases_per_epoch * len(epoch) + main_epoch = epoch // self.phases_per_epoch + + # Phase with in the main epoch + local_phase = epoch % self.phases_per_epoch + + # Start of new data-epoch or job is resumed after preemtion. + if local_phase == 0 or self.chunks[d_idx] is None: + # set seed for dataset epoch + # If using RepeatFactorWrapper, this step currectly re-samples indices before chunking. + self._set_dataset_epoch(dataset, main_epoch) + + # Separate random generator for subset sampling + g = torch.Generator() + g.manual_seed(main_epoch) + self.chunks[d_idx] = torch.chunk( + torch.randperm(len(dataset), generator=g), + self.phases_per_epoch, + ) + + dataset = Subset(dataset, self.chunks[d_idx][local_phase]) + else: + self._set_dataset_epoch(dataset, epoch) + + sampler = DistributedSampler(dataset, shuffle=self.shuffle) + sampler.set_epoch(epoch) + + batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last) + dataloaders.append( + DataLoader( + dataset, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + batch_sampler=batch_sampler, + collate_fn=self.collate_fn, + worker_init_fn=self.worker_init_fn, + ) + ) + return MixedDataLoader(dataloaders, self.dataset_prob) diff --git a/monst3r/third_party/sam2/training/dataset/transforms.py b/monst3r/third_party/sam2/training/dataset/transforms.py new file mode 100644 index 0000000..8e5c651 --- /dev/null +++ b/monst3r/third_party/sam2/training/dataset/transforms.py @@ -0,0 +1,528 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Transforms and data augmentation for both image + bbox. +""" + +import logging + +import random +from typing import Iterable + +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +import torchvision.transforms.v2.functional as Fv2 +from PIL import Image as PILImage + +from torchvision.transforms import InterpolationMode + +from training.utils.data_utils import VideoDatapoint + + +def hflip(datapoint, index): + + datapoint.frames[index].data = F.hflip(datapoint.frames[index].data) + for obj in datapoint.frames[index].objects: + if obj.segment is not None: + obj.segment = F.hflip(obj.segment) + + return datapoint + + +def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = max_size * min_original_size / max_original_size + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = int(round(size)) + oh = int(round(size * h / w)) + else: + oh = int(round(size)) + ow = int(round(size * w / h)) + + return (oh, ow) + + +def resize(datapoint, index, size, max_size=None, square=False, v2=False): + # size can be min_size (scalar) or (w, h) tuple + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + if square: + size = size, size + else: + cur_size = ( + datapoint.frames[index].data.size()[-2:][::-1] + if v2 + else datapoint.frames[index].data.size + ) + size = get_size(cur_size, size, max_size) + + old_size = ( + datapoint.frames[index].data.size()[-2:][::-1] + if v2 + else datapoint.frames[index].data.size + ) + if v2: + datapoint.frames[index].data = Fv2.resize( + datapoint.frames[index].data, size, antialias=True + ) + else: + datapoint.frames[index].data = F.resize(datapoint.frames[index].data, size) + + new_size = ( + datapoint.frames[index].data.size()[-2:][::-1] + if v2 + else datapoint.frames[index].data.size + ) + + for obj in datapoint.frames[index].objects: + if obj.segment is not None: + obj.segment = F.resize(obj.segment[None, None], size).squeeze() + + h, w = size + datapoint.frames[index].size = (h, w) + return datapoint + + +def pad(datapoint, index, padding, v2=False): + old_h, old_w = datapoint.frames[index].size + h, w = old_h, old_w + if len(padding) == 2: + # assumes that we only pad on the bottom right corners + datapoint.frames[index].data = F.pad( + datapoint.frames[index].data, (0, 0, padding[0], padding[1]) + ) + h += padding[1] + w += padding[0] + else: + # left, top, right, bottom + datapoint.frames[index].data = F.pad( + datapoint.frames[index].data, + (padding[0], padding[1], padding[2], padding[3]), + ) + h += padding[1] + padding[3] + w += padding[0] + padding[2] + + datapoint.frames[index].size = (h, w) + + for obj in datapoint.frames[index].objects: + if obj.segment is not None: + if v2: + if len(padding) == 2: + obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = Fv2.pad(obj.segment, tuple(padding)) + else: + if len(padding) == 2: + obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = F.pad(obj.segment, tuple(padding)) + return datapoint + + +class RandomHorizontalFlip: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + + def __call__(self, datapoint, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for i in range(len(datapoint.frames)): + datapoint = hflip(datapoint, i) + return datapoint + for i in range(len(datapoint.frames)): + if random.random() < self.p: + datapoint = hflip(datapoint, i) + return datapoint + + +class RandomResizeAPI: + def __init__( + self, sizes, consistent_transform, max_size=None, square=False, v2=False + ): + if isinstance(sizes, int): + sizes = (sizes,) + assert isinstance(sizes, Iterable) + self.sizes = list(sizes) + self.max_size = max_size + self.square = square + self.consistent_transform = consistent_transform + self.v2 = v2 + + def __call__(self, datapoint, **kwargs): + if self.consistent_transform: + size = random.choice(self.sizes) + for i in range(len(datapoint.frames)): + datapoint = resize( + datapoint, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return datapoint + for i in range(len(datapoint.frames)): + size = random.choice(self.sizes) + datapoint = resize( + datapoint, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return datapoint + + +class ToTensorAPI: + def __init__(self, v2=False): + self.v2 = v2 + + def __call__(self, datapoint: VideoDatapoint, **kwargs): + for img in datapoint.frames: + if self.v2: + img.data = Fv2.to_image_tensor(img.data) + else: + img.data = F.to_tensor(img.data) + return datapoint + + +class NormalizeAPI: + def __init__(self, mean, std, v2=False): + self.mean = mean + self.std = std + self.v2 = v2 + + def __call__(self, datapoint: VideoDatapoint, **kwargs): + for img in datapoint.frames: + if self.v2: + img.data = Fv2.convert_image_dtype(img.data, torch.float32) + img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std) + else: + img.data = F.normalize(img.data, mean=self.mean, std=self.std) + + return datapoint + + +class ComposeAPI: + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, datapoint, **kwargs): + for t in self.transforms: + datapoint = t(datapoint, **kwargs) + return datapoint + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class RandomGrayscale: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + self.Grayscale = T.Grayscale(num_output_channels=3) + + def __call__(self, datapoint: VideoDatapoint, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for img in datapoint.frames: + img.data = self.Grayscale(img.data) + return datapoint + for img in datapoint.frames: + if random.random() < self.p: + img.data = self.Grayscale(img.data) + return datapoint + + +class ColorJitter: + def __init__(self, consistent_transform, brightness, contrast, saturation, hue): + self.consistent_transform = consistent_transform + self.brightness = ( + brightness + if isinstance(brightness, list) + else [max(0, 1 - brightness), 1 + brightness] + ) + self.contrast = ( + contrast + if isinstance(contrast, list) + else [max(0, 1 - contrast), 1 + contrast] + ) + self.saturation = ( + saturation + if isinstance(saturation, list) + else [max(0, 1 - saturation), 1 + saturation] + ) + self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue]) + + def __call__(self, datapoint: VideoDatapoint, **kwargs): + if self.consistent_transform: + # Create a color jitter transformation params + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for img in datapoint.frames: + if not self.consistent_transform: + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img.data = F.adjust_brightness(img.data, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img.data = F.adjust_contrast(img.data, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img.data = F.adjust_saturation(img.data, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img.data = F.adjust_hue(img.data, hue_factor) + return datapoint + + +class RandomAffine: + def __init__( + self, + degrees, + consistent_transform, + scale=None, + translate=None, + shear=None, + image_mean=(123, 116, 103), + log_warning=True, + num_tentatives=1, + image_interpolation="bicubic", + ): + """ + The mask is required for this transform. + if consistent_transform if True, then the same random affine is applied to all frames and masks. + """ + self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees]) + self.scale = scale + self.shear = ( + shear if isinstance(shear, list) else ([-shear, shear] if shear else None) + ) + self.translate = translate + self.fill_img = image_mean + self.consistent_transform = consistent_transform + self.log_warning = log_warning + self.num_tentatives = num_tentatives + + if image_interpolation == "bicubic": + self.image_interpolation = InterpolationMode.BICUBIC + elif image_interpolation == "bilinear": + self.image_interpolation = InterpolationMode.BILINEAR + else: + raise NotImplementedError + + def __call__(self, datapoint: VideoDatapoint, **kwargs): + for _tentative in range(self.num_tentatives): + res = self.transform_datapoint(datapoint) + if res is not None: + return res + + if self.log_warning: + logging.warning( + f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives" + ) + return datapoint + + def transform_datapoint(self, datapoint: VideoDatapoint): + _, height, width = F.get_dimensions(datapoint.frames[0].data) + img_size = [width, height] + + if self.consistent_transform: + # Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + + for img_idx, img in enumerate(datapoint.frames): + this_masks = [ + obj.segment.unsqueeze(0) if obj.segment is not None else None + for obj in img.objects + ] + if not self.consistent_transform: + # if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + + transformed_bboxes, transformed_masks = [], [] + for i in range(len(img.objects)): + if this_masks[i] is None: + transformed_masks.append(None) + # Dummy bbox for a dummy target + transformed_bboxes.append(torch.tensor([[0, 0, 1, 1]])) + else: + transformed_mask = F.affine( + this_masks[i], + *affine_params, + interpolation=InterpolationMode.NEAREST, + fill=0.0, + ) + if img_idx == 0 and transformed_mask.max() == 0: + # We are dealing with a video and the object is not visible in the first frame + # Return the datapoint without transformation + return None + transformed_masks.append(transformed_mask.squeeze()) + + for i in range(len(img.objects)): + img.objects[i].segment = transformed_masks[i] + + img.data = F.affine( + img.data, + *affine_params, + interpolation=self.image_interpolation, + fill=self.fill_img, + ) + return datapoint + + +def random_mosaic_frame( + datapoint, + index, + grid_h, + grid_w, + target_grid_y, + target_grid_x, + should_hflip, +): + # Step 1: downsize the images and paste them into a mosaic + image_data = datapoint.frames[index].data + is_pil = isinstance(image_data, PILImage.Image) + if is_pil: + H_im = image_data.height + W_im = image_data.width + image_data_output = PILImage.new("RGB", (W_im, H_im)) + else: + H_im = image_data.size(-2) + W_im = image_data.size(-1) + image_data_output = torch.zeros_like(image_data) + + downsize_cache = {} + for grid_y in range(grid_h): + for grid_x in range(grid_w): + y_offset_b = grid_y * H_im // grid_h + x_offset_b = grid_x * W_im // grid_w + y_offset_e = (grid_y + 1) * H_im // grid_h + x_offset_e = (grid_x + 1) * W_im // grid_w + H_im_downsize = y_offset_e - y_offset_b + W_im_downsize = x_offset_e - x_offset_b + + if (H_im_downsize, W_im_downsize) in downsize_cache: + image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)] + else: + image_data_downsize = F.resize( + image_data, + size=(H_im_downsize, W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + ) + downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize + if should_hflip[grid_y, grid_x].item(): + image_data_downsize = F.hflip(image_data_downsize) + + if is_pil: + image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b)) + else: + image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = ( + image_data_downsize + ) + + datapoint.frames[index].data = image_data_output + + # Step 2: downsize the masks and paste them into the target grid of the mosaic + for obj in datapoint.frames[index].objects: + if obj.segment is None: + continue + assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8 + segment_output = torch.zeros_like(obj.segment) + + target_y_offset_b = target_grid_y * H_im // grid_h + target_x_offset_b = target_grid_x * W_im // grid_w + target_y_offset_e = (target_grid_y + 1) * H_im // grid_h + target_x_offset_e = (target_grid_x + 1) * W_im // grid_w + target_H_im_downsize = target_y_offset_e - target_y_offset_b + target_W_im_downsize = target_x_offset_e - target_x_offset_b + + segment_downsize = F.resize( + obj.segment[None, None], + size=(target_H_im_downsize, target_W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + )[0, 0] + if should_hflip[target_grid_y, target_grid_x].item(): + segment_downsize = F.hflip(segment_downsize[None, None])[0, 0] + + segment_output[ + target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e + ] = segment_downsize + obj.segment = segment_output + + return datapoint + + +class RandomMosaicVideoAPI: + def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False): + self.prob = prob + self.grid_h = grid_h + self.grid_w = grid_w + self.use_random_hflip = use_random_hflip + + def __call__(self, datapoint, **kwargs): + if random.random() > self.prob: + return datapoint + + # select a random location to place the target mask in the mosaic + target_grid_y = random.randint(0, self.grid_h - 1) + target_grid_x = random.randint(0, self.grid_w - 1) + # whether to flip each grid in the mosaic horizontally + if self.use_random_hflip: + should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5 + else: + should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool) + for i in range(len(datapoint.frames)): + datapoint = random_mosaic_frame( + datapoint, + i, + grid_h=self.grid_h, + grid_w=self.grid_w, + target_grid_y=target_grid_y, + target_grid_x=target_grid_x, + should_hflip=should_hflip, + ) + + return datapoint diff --git a/monst3r/third_party/sam2/training/dataset/utils.py b/monst3r/third_party/sam2/training/dataset/utils.py new file mode 100644 index 0000000..a658df2 --- /dev/null +++ b/monst3r/third_party/sam2/training/dataset/utils.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular""" + +from typing import Iterable + +import torch +from torch.utils.data import ( + ConcatDataset as TorchConcatDataset, + Dataset, + Subset as TorchSubset, +) + + +class ConcatDataset(TorchConcatDataset): + def __init__(self, datasets: Iterable[Dataset]) -> None: + super(ConcatDataset, self).__init__(datasets) + + self.repeat_factors = torch.cat([d.repeat_factors for d in datasets]) + + def set_epoch(self, epoch: int): + for dataset in self.datasets: + if hasattr(dataset, "epoch"): + dataset.epoch = epoch + if hasattr(dataset, "set_epoch"): + dataset.set_epoch(epoch) + + +class Subset(TorchSubset): + def __init__(self, dataset, indices) -> None: + super(Subset, self).__init__(dataset, indices) + + self.repeat_factors = dataset.repeat_factors[indices] + assert len(indices) == len(self.repeat_factors) + + +# Adapted from Detectron2 +class RepeatFactorWrapper(Dataset): + """ + Thin wrapper around a dataset to implement repeat factor sampling. + The underlying dataset must have a repeat_factors member to indicate the per-image factor. + Set it to uniformly ones to disable repeat factor sampling + """ + + def __init__(self, dataset, seed: int = 0): + self.dataset = dataset + self.epoch_ids = None + self._seed = seed + + # Split into whole number (_int_part) and fractional (_frac_part) parts. + self._int_part = torch.trunc(dataset.repeat_factors) + self._frac_part = dataset.repeat_factors - self._int_part + + def _get_epoch_indices(self, generator): + """ + Create a list of dataset indices (with repeats) to use for one epoch. + + Args: + generator (torch.Generator): pseudo random number generator used for + stochastic rounding. + + Returns: + torch.Tensor: list of dataset indices to use in one epoch. Each index + is repeated based on its calculated repeat factor. + """ + # Since repeat factors are fractional, we use stochastic rounding so + # that the target repeat factor is achieved in expectation over the + # course of training + rands = torch.rand(len(self._frac_part), generator=generator) + rep_factors = self._int_part + (rands < self._frac_part).float() + # Construct a list of indices in which we repeat images as specified + indices = [] + for dataset_index, rep_factor in enumerate(rep_factors): + indices.extend([dataset_index] * int(rep_factor.item())) + return torch.tensor(indices, dtype=torch.int64) + + def __len__(self): + if self.epoch_ids is None: + # Here we raise an error instead of returning original len(self.dataset) avoid + # accidentally using unwrapped length. Otherwise it's error-prone since the + # length changes to `len(self.epoch_ids)`changes after set_epoch is called. + raise RuntimeError("please call set_epoch first to get wrapped length") + # return len(self.dataset) + + return len(self.epoch_ids) + + def set_epoch(self, epoch: int): + g = torch.Generator() + g.manual_seed(self._seed + epoch) + self.epoch_ids = self._get_epoch_indices(g) + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + def __getitem__(self, idx): + if self.epoch_ids is None: + raise RuntimeError( + "Repeat ids haven't been computed. Did you forget to call set_epoch?" + ) + + return self.dataset[self.epoch_ids[idx]] diff --git a/monst3r/third_party/sam2/training/dataset/vos_dataset.py b/monst3r/third_party/sam2/training/dataset/vos_dataset.py new file mode 100644 index 0000000..d1e9d39 --- /dev/null +++ b/monst3r/third_party/sam2/training/dataset/vos_dataset.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import random +from copy import deepcopy + +import numpy as np + +import torch +from iopath.common.file_io import g_pathmgr +from PIL import Image as PILImage +from torchvision.datasets.vision import VisionDataset + +from training.dataset.vos_raw_dataset import VOSRawDataset +from training.dataset.vos_sampler import VOSSampler +from training.dataset.vos_segment_loader import JSONSegmentLoader + +from training.utils.data_utils import Frame, Object, VideoDatapoint + +MAX_RETRIES = 100 + + +class VOSDataset(VisionDataset): + def __init__( + self, + transforms, + training: bool, + video_dataset: VOSRawDataset, + sampler: VOSSampler, + multiplier: int, + always_target=True, + target_segments_available=True, + ): + self._transforms = transforms + self.training = training + self.video_dataset = video_dataset + self.sampler = sampler + + self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32) + self.repeat_factors *= multiplier + print(f"Raw dataset length = {len(self.video_dataset)}") + + self.curr_epoch = 0 # Used in case data loader behavior changes across epochs + self.always_target = always_target + self.target_segments_available = target_segments_available + + def _get_datapoint(self, idx): + + for retry in range(MAX_RETRIES): + try: + if isinstance(idx, torch.Tensor): + idx = idx.item() + # sample a video + video, segment_loader = self.video_dataset.get_video(idx) + # sample frames and object indices to be used in a datapoint + sampled_frms_and_objs = self.sampler.sample( + video, segment_loader, epoch=self.curr_epoch + ) + break # Succesfully loaded video + except Exception as e: + if self.training: + logging.warning( + f"Loading failed (id={idx}); Retry {retry} with exception: {e}" + ) + idx = random.randrange(0, len(self.video_dataset)) + else: + # Shouldn't fail to load a val video + raise e + + datapoint = self.construct(video, sampled_frms_and_objs, segment_loader) + for transform in self._transforms: + datapoint = transform(datapoint, epoch=self.curr_epoch) + return datapoint + + def construct(self, video, sampled_frms_and_objs, segment_loader): + """ + Constructs a VideoDatapoint sample to pass to transforms + """ + sampled_frames = sampled_frms_and_objs.frames + sampled_object_ids = sampled_frms_and_objs.object_ids + + images = [] + rgb_images = load_images(sampled_frames) + # Iterate over the sampled frames and store their rgb data and object data (bbox, segment) + for frame_idx, frame in enumerate(sampled_frames): + w, h = rgb_images[frame_idx].size + images.append( + Frame( + data=rgb_images[frame_idx], + objects=[], + ) + ) + # We load the gt segments associated with the current frame + if isinstance(segment_loader, JSONSegmentLoader): + segments = segment_loader.load( + frame.frame_idx, obj_ids=sampled_object_ids + ) + else: + segments = segment_loader.load(frame.frame_idx) + for obj_id in sampled_object_ids: + # Extract the segment + if obj_id in segments: + assert ( + segments[obj_id] is not None + ), "None targets are not supported" + # segment is uint8 and remains uint8 throughout the transforms + segment = segments[obj_id].to(torch.uint8) + else: + # There is no target, we either use a zero mask target or drop this object + if not self.always_target: + continue + segment = torch.zeros(h, w, dtype=torch.uint8) + + images[frame_idx].objects.append( + Object( + object_id=obj_id, + frame_index=frame.frame_idx, + segment=segment, + ) + ) + return VideoDatapoint( + frames=images, + video_id=video.video_id, + size=(h, w), + ) + + def __getitem__(self, idx): + return self._get_datapoint(idx) + + def __len__(self): + return len(self.video_dataset) + + +def load_images(frames): + all_images = [] + cache = {} + for frame in frames: + if frame.data is None: + # Load the frame rgb data from file + path = frame.image_path + if path in cache: + all_images.append(deepcopy(all_images[cache[path]])) + continue + with g_pathmgr.open(path, "rb") as fopen: + all_images.append(PILImage.open(fopen).convert("RGB")) + cache[path] = len(all_images) - 1 + else: + # The frame rgb data has already been loaded + # Convert it to a PILImage + all_images.append(tensor_2_PIL(frame.data)) + + return all_images + + +def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image: + data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0 + data = data.astype(np.uint8) + return PILImage.fromarray(data) diff --git a/monst3r/third_party/sam2/training/dataset/vos_raw_dataset.py b/monst3r/third_party/sam2/training/dataset/vos_raw_dataset.py new file mode 100644 index 0000000..44fe893 --- /dev/null +++ b/monst3r/third_party/sam2/training/dataset/vos_raw_dataset.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import glob +import logging +import os +from dataclasses import dataclass + +from typing import List, Optional + +import pandas as pd + +import torch + +from iopath.common.file_io import g_pathmgr + +from omegaconf.listconfig import ListConfig + +from training.dataset.vos_segment_loader import ( + JSONSegmentLoader, + MultiplePNGSegmentLoader, + PalettisedPNGSegmentLoader, + SA1BSegmentLoader, +) + + +@dataclass +class VOSFrame: + frame_idx: int + image_path: str + data: Optional[torch.Tensor] = None + is_conditioning_only: Optional[bool] = False + + +@dataclass +class VOSVideo: + video_name: str + video_id: int + frames: List[VOSFrame] + + def __len__(self): + return len(self.frames) + + +class VOSRawDataset: + def __init__(self): + pass + + def get_video(self, idx): + raise NotImplementedError() + + +class PNGRawDataset(VOSRawDataset): + def __init__( + self, + img_folder, + gt_folder, + file_list_txt=None, + excluded_videos_list_txt=None, + sample_rate=1, + is_palette=True, + single_object_mode=False, + truncate_video=-1, + frames_sampling_mult=False, + ): + self.img_folder = img_folder + self.gt_folder = gt_folder + self.sample_rate = sample_rate + self.is_palette = is_palette + self.single_object_mode = single_object_mode + self.truncate_video = truncate_video + + # Read the subset defined in file_list_txt + if file_list_txt is not None: + with g_pathmgr.open(file_list_txt, "r") as f: + subset = [os.path.splitext(line.strip())[0] for line in f] + else: + subset = os.listdir(self.img_folder) + + # Read and process excluded files if provided + if excluded_videos_list_txt is not None: + with g_pathmgr.open(excluded_videos_list_txt, "r") as f: + excluded_files = [os.path.splitext(line.strip())[0] for line in f] + else: + excluded_files = [] + + # Check if it's not in excluded_files + self.video_names = sorted( + [video_name for video_name in subset if video_name not in excluded_files] + ) + + if self.single_object_mode: + # single object mode + self.video_names = sorted( + [ + os.path.join(video_name, obj) + for video_name in self.video_names + for obj in os.listdir(os.path.join(self.gt_folder, video_name)) + ] + ) + + if frames_sampling_mult: + video_names_mult = [] + for video_name in self.video_names: + num_frames = len(os.listdir(os.path.join(self.img_folder, video_name))) + video_names_mult.extend([video_name] * num_frames) + self.video_names = video_names_mult + + def get_video(self, idx): + """ + Given a VOSVideo object, return the mask tensors. + """ + video_name = self.video_names[idx] + + if self.single_object_mode: + video_frame_root = os.path.join( + self.img_folder, os.path.dirname(video_name) + ) + else: + video_frame_root = os.path.join(self.img_folder, video_name) + + video_mask_root = os.path.join(self.gt_folder, video_name) + + if self.is_palette: + segment_loader = PalettisedPNGSegmentLoader(video_mask_root) + else: + segment_loader = MultiplePNGSegmentLoader( + video_mask_root, self.single_object_mode + ) + + all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg"))) + if self.truncate_video > 0: + all_frames = all_frames[: self.truncate_video] + frames = [] + for _, fpath in enumerate(all_frames[:: self.sample_rate]): + fid = int(os.path.basename(fpath).split(".")[0]) + frames.append(VOSFrame(fid, image_path=fpath)) + video = VOSVideo(video_name, idx, frames) + return video, segment_loader + + def __len__(self): + return len(self.video_names) + + +class SA1BRawDataset(VOSRawDataset): + def __init__( + self, + img_folder, + gt_folder, + file_list_txt=None, + excluded_videos_list_txt=None, + num_frames=1, + mask_area_frac_thresh=1.1, # no filtering by default + uncertain_iou=-1, # no filtering by default + ): + self.img_folder = img_folder + self.gt_folder = gt_folder + self.num_frames = num_frames + self.mask_area_frac_thresh = mask_area_frac_thresh + self.uncertain_iou = uncertain_iou # stability score + + # Read the subset defined in file_list_txt + if file_list_txt is not None: + with g_pathmgr.open(file_list_txt, "r") as f: + subset = [os.path.splitext(line.strip())[0] for line in f] + else: + subset = os.listdir(self.img_folder) + subset = [ + path.split(".")[0] for path in subset if path.endswith(".jpg") + ] # remove extension + + # Read and process excluded files if provided + if excluded_videos_list_txt is not None: + with g_pathmgr.open(excluded_videos_list_txt, "r") as f: + excluded_files = [os.path.splitext(line.strip())[0] for line in f] + else: + excluded_files = [] + + # Check if it's not in excluded_files and it exists + self.video_names = [ + video_name for video_name in subset if video_name not in excluded_files + ] + + def get_video(self, idx): + """ + Given a VOSVideo object, return the mask tensors. + """ + video_name = self.video_names[idx] + + video_frame_path = os.path.join(self.img_folder, video_name + ".jpg") + video_mask_path = os.path.join(self.gt_folder, video_name + ".json") + + segment_loader = SA1BSegmentLoader( + video_mask_path, + mask_area_frac_thresh=self.mask_area_frac_thresh, + video_frame_path=video_frame_path, + uncertain_iou=self.uncertain_iou, + ) + + frames = [] + for frame_idx in range(self.num_frames): + frames.append(VOSFrame(frame_idx, image_path=video_frame_path)) + video_name = video_name.split("_")[-1] # filename is sa_{int} + # video id needs to be image_id to be able to load correct annotation file during eval + video = VOSVideo(video_name, int(video_name), frames) + return video, segment_loader + + def __len__(self): + return len(self.video_names) + + +class JSONRawDataset(VOSRawDataset): + """ + Dataset where the annotation in the format of SA-V json files + """ + + def __init__( + self, + img_folder, + gt_folder, + file_list_txt=None, + excluded_videos_list_txt=None, + sample_rate=1, + rm_unannotated=True, + ann_every=1, + frames_fps=24, + ): + self.gt_folder = gt_folder + self.img_folder = img_folder + self.sample_rate = sample_rate + self.rm_unannotated = rm_unannotated + self.ann_every = ann_every + self.frames_fps = frames_fps + + # Read and process excluded files if provided + excluded_files = [] + if excluded_videos_list_txt is not None: + if isinstance(excluded_videos_list_txt, str): + excluded_videos_lists = [excluded_videos_list_txt] + elif isinstance(excluded_videos_list_txt, ListConfig): + excluded_videos_lists = list(excluded_videos_list_txt) + else: + raise NotImplementedError + + for excluded_videos_list_txt in excluded_videos_lists: + with open(excluded_videos_list_txt, "r") as f: + excluded_files.extend( + [os.path.splitext(line.strip())[0] for line in f] + ) + excluded_files = set(excluded_files) + + # Read the subset defined in file_list_txt + if file_list_txt is not None: + with g_pathmgr.open(file_list_txt, "r") as f: + subset = [os.path.splitext(line.strip())[0] for line in f] + else: + subset = os.listdir(self.img_folder) + + self.video_names = sorted( + [video_name for video_name in subset if video_name not in excluded_files] + ) + + def get_video(self, video_idx): + """ + Given a VOSVideo object, return the mask tensors. + """ + video_name = self.video_names[video_idx] + video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json") + segment_loader = JSONSegmentLoader( + video_json_path=video_json_path, + ann_every=self.ann_every, + frames_fps=self.frames_fps, + ) + + frame_ids = [ + int(os.path.splitext(frame_name)[0]) + for frame_name in sorted( + os.listdir(os.path.join(self.img_folder, video_name)) + ) + ] + + frames = [ + VOSFrame( + frame_id, + image_path=os.path.join( + self.img_folder, f"{video_name}/%05d.jpg" % (frame_id) + ), + ) + for frame_id in frame_ids[:: self.sample_rate] + ] + + if self.rm_unannotated: + # Eliminate the frames that have not been annotated + valid_frame_ids = [ + i * segment_loader.ann_every + for i, annot in enumerate(segment_loader.frame_annots) + if annot is not None and None not in annot + ] + frames = [f for f in frames if f.frame_idx in valid_frame_ids] + + video = VOSVideo(video_name, video_idx, frames) + return video, segment_loader + + def __len__(self): + return len(self.video_names) diff --git a/monst3r/third_party/sam2/training/dataset/vos_sampler.py b/monst3r/third_party/sam2/training/dataset/vos_sampler.py new file mode 100644 index 0000000..1ad84b7 --- /dev/null +++ b/monst3r/third_party/sam2/training/dataset/vos_sampler.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random +from dataclasses import dataclass +from typing import List + +from training.dataset.vos_segment_loader import LazySegments + +MAX_RETRIES = 1000 + + +@dataclass +class SampledFramesAndObjects: + frames: List[int] + object_ids: List[int] + + +class VOSSampler: + def __init__(self, sort_frames=True): + # frames are ordered by frame id when sort_frames is True + self.sort_frames = sort_frames + + def sample(self, video): + raise NotImplementedError() + + +class RandomUniformSampler(VOSSampler): + def __init__( + self, + num_frames, + max_num_objects, + reverse_time_prob=0.0, + ): + self.num_frames = num_frames + self.max_num_objects = max_num_objects + self.reverse_time_prob = reverse_time_prob + + def sample(self, video, segment_loader, epoch=None): + + for retry in range(MAX_RETRIES): + if len(video.frames) < self.num_frames: + raise Exception( + f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames." + ) + start = random.randrange(0, len(video.frames) - self.num_frames + 1) + frames = [video.frames[start + step] for step in range(self.num_frames)] + if random.uniform(0, 1) < self.reverse_time_prob: + # Reverse time + frames = frames[::-1] + + # Get first frame object ids + visible_object_ids = [] + loaded_segms = segment_loader.load(frames[0].frame_idx) + if isinstance(loaded_segms, LazySegments): + # LazySegments for SA1BRawDataset + visible_object_ids = list(loaded_segms.keys()) + else: + for object_id, segment in segment_loader.load( + frames[0].frame_idx + ).items(): + if segment.sum(): + visible_object_ids.append(object_id) + + # First frame needs to have at least a target to track + if len(visible_object_ids) > 0: + break + if retry >= MAX_RETRIES - 1: + raise Exception("No visible objects") + + object_ids = random.sample( + visible_object_ids, + min(len(visible_object_ids), self.max_num_objects), + ) + return SampledFramesAndObjects(frames=frames, object_ids=object_ids) + + +class EvalSampler(VOSSampler): + """ + VOS Sampler for evaluation: sampling all the frames and all the objects in a video + """ + + def __init__( + self, + ): + super().__init__() + + def sample(self, video, segment_loader, epoch=None): + """ + Sampling all the frames and all the objects + """ + if self.sort_frames: + # ordered by frame id + frames = sorted(video.frames, key=lambda x: x.frame_idx) + else: + # use the original order + frames = video.frames + object_ids = segment_loader.load(frames[0].frame_idx).keys() + if len(object_ids) == 0: + raise Exception("First frame of the video has no objects") + + return SampledFramesAndObjects(frames=frames, object_ids=object_ids) diff --git a/monst3r/third_party/sam2/training/dataset/vos_segment_loader.py b/monst3r/third_party/sam2/training/dataset/vos_segment_loader.py new file mode 100644 index 0000000..27e1701 --- /dev/null +++ b/monst3r/third_party/sam2/training/dataset/vos_segment_loader.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import glob +import json +import os + +import numpy as np +import pandas as pd +import torch + +from PIL import Image as PILImage + +try: + from pycocotools import mask as mask_utils +except: + pass + + +class JSONSegmentLoader: + def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None): + # Annotations in the json are provided every ann_every th frame + self.ann_every = ann_every + # Ids of the objects to consider when sampling this video + self.valid_obj_ids = valid_obj_ids + with open(video_json_path, "r") as f: + data = json.load(f) + if isinstance(data, list): + self.frame_annots = data + elif isinstance(data, dict): + masklet_field_name = "masklet" if "masklet" in data else "masks" + self.frame_annots = data[masklet_field_name] + if "fps" in data: + if isinstance(data["fps"], list): + annotations_fps = int(data["fps"][0]) + else: + annotations_fps = int(data["fps"]) + assert frames_fps % annotations_fps == 0 + self.ann_every = frames_fps // annotations_fps + else: + raise NotImplementedError + + def load(self, frame_id, obj_ids=None): + assert frame_id % self.ann_every == 0 + rle_mask = self.frame_annots[frame_id // self.ann_every] + + valid_objs_ids = set(range(len(rle_mask))) + if self.valid_obj_ids is not None: + # Remove the masklets that have been filtered out for this video + valid_objs_ids &= set(self.valid_obj_ids) + if obj_ids is not None: + # Only keep the objects that have been sampled + valid_objs_ids &= set(obj_ids) + valid_objs_ids = sorted(list(valid_objs_ids)) + + # Construct rle_masks_filtered that only contains the rle masks we are interested in + id_2_idx = {} + rle_mask_filtered = [] + for obj_id in valid_objs_ids: + if rle_mask[obj_id] is not None: + id_2_idx[obj_id] = len(rle_mask_filtered) + rle_mask_filtered.append(rle_mask[obj_id]) + else: + id_2_idx[obj_id] = None + + # Decode the masks + raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute( + 2, 0, 1 + ) # (num_obj, h, w) + segments = {} + for obj_id in valid_objs_ids: + if id_2_idx[obj_id] is None: + segments[obj_id] = None + else: + idx = id_2_idx[obj_id] + segments[obj_id] = raw_segments[idx] + return segments + + def get_valid_obj_frames_ids(self, num_frames_min=None): + # For each object, find all the frames with a valid (not None) mask + num_objects = len(self.frame_annots[0]) + + # The result dict associates each obj_id with the id of its valid frames + res = {obj_id: [] for obj_id in range(num_objects)} + + for annot_idx, annot in enumerate(self.frame_annots): + for obj_id in range(num_objects): + if annot[obj_id] is not None: + res[obj_id].append(int(annot_idx * self.ann_every)) + + if num_frames_min is not None: + # Remove masklets that have less than num_frames_min valid masks + for obj_id, valid_frames in list(res.items()): + if len(valid_frames) < num_frames_min: + res.pop(obj_id) + + return res + + +class PalettisedPNGSegmentLoader: + def __init__(self, video_png_root): + """ + SegmentLoader for datasets with masks stored as palettised PNGs. + video_png_root: the folder contains all the masks stored in png + """ + self.video_png_root = video_png_root + # build a mapping from frame id to their PNG mask path + # note that in some datasets, the PNG paths could have more + # than 5 digits, e.g. "00000000.png" instead of "00000.png" + png_filenames = os.listdir(self.video_png_root) + self.frame_id_to_png_filename = {} + for filename in png_filenames: + frame_id, _ = os.path.splitext(filename) + self.frame_id_to_png_filename[int(frame_id)] = filename + + def load(self, frame_id): + """ + load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png') + Args: + frame_id: int, define the mask path + Return: + binary_segments: dict + """ + # check the path + mask_path = os.path.join( + self.video_png_root, self.frame_id_to_png_filename[frame_id] + ) + + # load the mask + masks = PILImage.open(mask_path).convert("P") + masks = np.array(masks) + + object_id = pd.unique(masks.flatten()) + object_id = object_id[object_id != 0] # remove background (0) + + # convert into N binary segmentation masks + binary_segments = {} + for i in object_id: + bs = masks == i + binary_segments[i] = torch.from_numpy(bs) + + return binary_segments + + def __len__(self): + return + + +class MultiplePNGSegmentLoader: + def __init__(self, video_png_root, single_object_mode=False): + """ + video_png_root: the folder contains all the masks stored in png + single_object_mode: whether to load only a single object at a time + """ + self.video_png_root = video_png_root + self.single_object_mode = single_object_mode + # read a mask to know the resolution of the video + if self.single_object_mode: + tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0] + else: + tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0] + tmp_mask = np.array(PILImage.open(tmp_mask_path)) + self.H = tmp_mask.shape[0] + self.W = tmp_mask.shape[1] + if self.single_object_mode: + self.obj_id = ( + int(video_png_root.split("/")[-1]) + 1 + ) # offset by 1 as bg is 0 + else: + self.obj_id = None + + def load(self, frame_id): + if self.single_object_mode: + return self._load_single_png(frame_id) + else: + return self._load_multiple_pngs(frame_id) + + def _load_single_png(self, frame_id): + """ + load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png') + Args: + frame_id: int, define the mask path + Return: + binary_segments: dict + """ + mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png") + binary_segments = {} + + if os.path.exists(mask_path): + mask = np.array(PILImage.open(mask_path)) + else: + # if png doesn't exist, empty mask + mask = np.zeros((self.H, self.W), dtype=bool) + binary_segments[self.obj_id] = torch.from_numpy(mask > 0) + return binary_segments + + def _load_multiple_pngs(self, frame_id): + """ + load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png') + Args: + frame_id: int, define the mask path + Return: + binary_segments: dict + """ + # get the path + all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*"))) + num_objects = len(all_objects) + assert num_objects > 0 + + # load the masks + binary_segments = {} + for obj_folder in all_objects: + # obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder + obj_id = int(obj_folder.split("/")[-1]) + obj_id = obj_id + 1 # offset 1 as bg is 0 + mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png") + if os.path.exists(mask_path): + mask = np.array(PILImage.open(mask_path)) + else: + mask = np.zeros((self.H, self.W), dtype=bool) + binary_segments[obj_id] = torch.from_numpy(mask > 0) + + return binary_segments + + def __len__(self): + return + + +class LazySegments: + """ + Only decodes segments that are actually used. + """ + + def __init__(self): + self.segments = {} + self.cache = {} + + def __setitem__(self, key, item): + self.segments[key] = item + + def __getitem__(self, key): + if key in self.cache: + return self.cache[key] + rle = self.segments[key] + mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0] + self.cache[key] = mask + return mask + + def __contains__(self, key): + return key in self.segments + + def __len__(self): + return len(self.segments) + + def keys(self): + return self.segments.keys() + + +class SA1BSegmentLoader: + def __init__( + self, + video_mask_path, + mask_area_frac_thresh=1.1, + video_frame_path=None, + uncertain_iou=-1, + ): + with open(video_mask_path, "r") as f: + self.frame_annots = json.load(f) + + if mask_area_frac_thresh <= 1.0: + # Lazily read frame + orig_w, orig_h = PILImage.open(video_frame_path).size + area = orig_w * orig_h + + self.frame_annots = self.frame_annots["annotations"] + + rle_masks = [] + for frame_annot in self.frame_annots: + if not frame_annot["area"] > 0: + continue + if ("uncertain_iou" in frame_annot) and ( + frame_annot["uncertain_iou"] < uncertain_iou + ): + # uncertain_iou is stability score + continue + if ( + mask_area_frac_thresh <= 1.0 + and (frame_annot["area"] / area) >= mask_area_frac_thresh + ): + continue + rle_masks.append(frame_annot["segmentation"]) + + self.segments = LazySegments() + for i, rle in enumerate(rle_masks): + self.segments[i] = rle + + def load(self, frame_idx): + return self.segments diff --git a/monst3r/third_party/sam2/training/loss_fns.py b/monst3r/third_party/sam2/training/loss_fns.py new file mode 100644 index 0000000..d281b1a --- /dev/null +++ b/monst3r/third_party/sam2/training/loss_fns.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from typing import Dict, List + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F + +from training.trainer import CORE_LOSS_KEY + +from training.utils.distributed import get_world_size, is_dist_avail_and_initialized + + +def dice_loss(inputs, targets, num_objects, loss_on_multimask=False): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + num_objects: Number of objects in the batch + loss_on_multimask: True if multimask prediction is enabled + Returns: + Dice loss tensor + """ + inputs = inputs.sigmoid() + if loss_on_multimask: + # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks + assert inputs.dim() == 4 and targets.dim() == 4 + # flatten spatial dimension while keeping multimask channel dimension + inputs = inputs.flatten(2) + targets = targets.flatten(2) + numerator = 2 * (inputs * targets).sum(-1) + else: + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +def sigmoid_focal_loss( + inputs, + targets, + num_objects, + alpha: float = 0.25, + gamma: float = 2, + loss_on_multimask=False, +): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + num_objects: Number of objects in the batch + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + loss_on_multimask: True if multimask prediction is enabled + Returns: + focal loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if loss_on_multimask: + # loss is [N, M, H, W] where M corresponds to multiple predicted masks + assert loss.dim() == 4 + return loss.flatten(2).mean(-1) / num_objects # average over spatial dims + return loss.mean(1).sum() / num_objects + + +def iou_loss( + inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False +): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + pred_ious: A float tensor containing the predicted IoUs scores per mask + num_objects: Number of objects in the batch + loss_on_multimask: True if multimask prediction is enabled + use_l1_loss: Whether to use L1 loss is used instead of MSE loss + Returns: + IoU loss tensor + """ + assert inputs.dim() == 4 and targets.dim() == 4 + pred_mask = inputs.flatten(2) > 0 + gt_mask = targets.flatten(2) > 0 + area_i = torch.sum(pred_mask & gt_mask, dim=-1).float() + area_u = torch.sum(pred_mask | gt_mask, dim=-1).float() + actual_ious = area_i / torch.clamp(area_u, min=1.0) + + if use_l1_loss: + loss = F.l1_loss(pred_ious, actual_ious, reduction="none") + else: + loss = F.mse_loss(pred_ious, actual_ious, reduction="none") + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +class MultiStepMultiMasksAndIous(nn.Module): + def __init__( + self, + weight_dict, + focal_alpha=0.25, + focal_gamma=2, + supervise_all_iou=False, + iou_use_l1_loss=False, + pred_obj_scores=False, + focal_gamma_obj_score=0.0, + focal_alpha_obj_score=-1, + ): + """ + This class computes the multi-step multi-mask and IoU losses. + Args: + weight_dict: dict containing weights for focal, dice, iou losses + focal_alpha: alpha for sigmoid focal loss + focal_gamma: gamma for sigmoid focal loss + supervise_all_iou: if True, back-prop iou losses for all predicted masks + iou_use_l1_loss: use L1 loss instead of MSE loss for iou + pred_obj_scores: if True, compute loss for object scores + focal_gamma_obj_score: gamma for sigmoid focal loss on object scores + focal_alpha_obj_score: alpha for sigmoid focal loss on object scores + """ + + super().__init__() + self.weight_dict = weight_dict + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + assert "loss_mask" in self.weight_dict + assert "loss_dice" in self.weight_dict + assert "loss_iou" in self.weight_dict + if "loss_class" not in self.weight_dict: + self.weight_dict["loss_class"] = 0.0 + + self.focal_alpha_obj_score = focal_alpha_obj_score + self.focal_gamma_obj_score = focal_gamma_obj_score + self.supervise_all_iou = supervise_all_iou + self.iou_use_l1_loss = iou_use_l1_loss + self.pred_obj_scores = pred_obj_scores + + def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor): + assert len(outs_batch) == len(targets_batch) + num_objects = torch.tensor( + (targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float + ) # Number of objects is fixed within a batch + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_objects) + num_objects = torch.clamp(num_objects / get_world_size(), min=1).item() + + losses = defaultdict(int) + for outs, targets in zip(outs_batch, targets_batch): + cur_losses = self._forward(outs, targets, num_objects) + for k, v in cur_losses.items(): + losses[k] += v + + return losses + + def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects): + """ + Compute the losses related to the masks: the focal loss and the dice loss. + and also the MAE or MSE loss between predicted IoUs and actual IoUs. + + Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors + of shape [N, M, H, W], where M could be 1 or larger, corresponding to + one or multiple predicted masks from a click. + + We back-propagate focal, dice losses only on the prediction channel + with the lowest focal+dice loss between predicted mask and ground-truth. + If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks. + """ + + target_masks = targets.unsqueeze(1).float() + assert target_masks.dim() == 4 # [N, 1, H, W] + src_masks_list = outputs["multistep_pred_multimasks_high_res"] + ious_list = outputs["multistep_pred_ious"] + object_score_logits_list = outputs["multistep_object_score_logits"] + + assert len(src_masks_list) == len(ious_list) + assert len(object_score_logits_list) == len(ious_list) + + # accumulate the loss over prediction steps + losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0} + for src_masks, ious, object_score_logits in zip( + src_masks_list, ious_list, object_score_logits_list + ): + self._update_losses( + losses, src_masks, target_masks, ious, num_objects, object_score_logits + ) + losses[CORE_LOSS_KEY] = self.reduce_loss(losses) + return losses + + def _update_losses( + self, losses, src_masks, target_masks, ious, num_objects, object_score_logits + ): + target_masks = target_masks.expand_as(src_masks) + # get focal, dice and iou loss on all output masks in a prediction step + loss_multimask = sigmoid_focal_loss( + src_masks, + target_masks, + num_objects, + alpha=self.focal_alpha, + gamma=self.focal_gamma, + loss_on_multimask=True, + ) + loss_multidice = dice_loss( + src_masks, target_masks, num_objects, loss_on_multimask=True + ) + if not self.pred_obj_scores: + loss_class = torch.tensor( + 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device + ) + target_obj = torch.ones( + loss_multimask.shape[0], + 1, + dtype=loss_multimask.dtype, + device=loss_multimask.device, + ) + else: + target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[ + ..., None + ].float() + loss_class = sigmoid_focal_loss( + object_score_logits, + target_obj, + num_objects, + alpha=self.focal_alpha_obj_score, + gamma=self.focal_gamma_obj_score, + ) + + loss_multiiou = iou_loss( + src_masks, + target_masks, + ious, + num_objects, + loss_on_multimask=True, + use_l1_loss=self.iou_use_l1_loss, + ) + assert loss_multimask.dim() == 2 + assert loss_multidice.dim() == 2 + assert loss_multiiou.dim() == 2 + if loss_multimask.size(1) > 1: + # take the mask indices with the smallest focal + dice loss for back propagation + loss_combo = ( + loss_multimask * self.weight_dict["loss_mask"] + + loss_multidice * self.weight_dict["loss_dice"] + ) + best_loss_inds = torch.argmin(loss_combo, dim=-1) + batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device) + loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1) + loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1) + # calculate the iou prediction and slot losses only in the index + # with the minimum loss for each mask (to be consistent w/ SAM) + if self.supervise_all_iou: + loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1) + else: + loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1) + else: + loss_mask = loss_multimask + loss_dice = loss_multidice + loss_iou = loss_multiiou + + # backprop focal, dice and iou loss only if obj present + loss_mask = loss_mask * target_obj + loss_dice = loss_dice * target_obj + loss_iou = loss_iou * target_obj + + # sum over batch dimension (note that the losses are already divided by num_objects) + losses["loss_mask"] += loss_mask.sum() + losses["loss_dice"] += loss_dice.sum() + losses["loss_iou"] += loss_iou.sum() + losses["loss_class"] += loss_class + + def reduce_loss(self, losses): + reduced_loss = 0.0 + for loss_key, weight in self.weight_dict.items(): + if loss_key not in losses: + raise ValueError(f"{type(self)} doesn't compute {loss_key}") + if weight != 0: + reduced_loss += losses[loss_key] * weight + + return reduced_loss diff --git a/monst3r/third_party/sam2/training/model/__init__.py b/monst3r/third_party/sam2/training/model/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/monst3r/third_party/sam2/training/model/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/monst3r/third_party/sam2/training/model/sam2.py b/monst3r/third_party/sam2/training/model/sam2.py new file mode 100644 index 0000000..ef7567c --- /dev/null +++ b/monst3r/third_party/sam2/training/model/sam2.py @@ -0,0 +1,541 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch +import torch.distributed +from sam2.modeling.sam2_base import SAM2Base +from sam2.modeling.sam2_utils import ( + get_1d_sine_pe, + get_next_point, + sample_box_points, + select_closest_cond_frames, +) + +from sam2.utils.misc import concat_points + +from training.utils.data_utils import BatchedVideoDatapoint + + +class SAM2Train(SAM2Base): + def __init__( + self, + image_encoder, + memory_attention=None, + memory_encoder=None, + prob_to_use_pt_input_for_train=0.0, + prob_to_use_pt_input_for_eval=0.0, + prob_to_use_box_input_for_train=0.0, + prob_to_use_box_input_for_eval=0.0, + # if it is greater than 1, we interactive point sampling in the 1st frame and other randomly selected frames + num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame + num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame + rand_frames_to_correct_for_train=False, + rand_frames_to_correct_for_eval=False, + # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame) + # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames + # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames + # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`; + # these are initial conditioning frames because as we track the video, more conditioning frames might be added + # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True` + num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame + num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame + rand_init_cond_frames_for_train=True, # default: random 1~num_init_cond_frames_for_train cond frames (to be constent w/ previous TA data loader) + rand_init_cond_frames_for_eval=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # how many additional correction points to sample (on each frame selected to be corrected) + # note that the first frame receives an initial input click (in addition to any correction clicks) + num_correction_pt_per_frame=7, + # method for point sampling during evaluation + # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary) + # default to "center" to be consistent with evaluation in the SAM paper + pt_sampling_for_eval="center", + # During training, we optionally allow sampling the correction points from GT regions + # instead of the prediction error regions with a small probability. This might allow the + # model to overfit less to the error regions in training datasets + prob_to_sample_from_gt_for_train=0.0, + use_act_ckpt_iterative_pt_sampling=False, + # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features + # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower. + forward_backbone_per_frame_for_eval=False, + freeze_image_encoder=False, + **kwargs, + ): + super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs) + self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling + self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval + + # Point sampler and conditioning frames + self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train + self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train + self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval + self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval + if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0: + logging.info( + f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}" + ) + assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train + assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval + + self.num_frames_to_correct_for_train = num_frames_to_correct_for_train + self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval + self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train + self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval + # Initial multi-conditioning frames + self.num_init_cond_frames_for_train = num_init_cond_frames_for_train + self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval + self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train + self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.num_correction_pt_per_frame = num_correction_pt_per_frame + self.pt_sampling_for_eval = pt_sampling_for_eval + self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train + # A random number generator with a fixed initial seed across GPUs + self.rng = np.random.default_rng(seed=42) + + if freeze_image_encoder: + for p in self.image_encoder.parameters(): + p.requires_grad = False + + def forward(self, input: BatchedVideoDatapoint): + if self.training or not self.forward_backbone_per_frame_for_eval: + # precompute image features on all frames before tracking + backbone_out = self.forward_image(input.flat_img_batch) + else: + # defer image feature computation on a frame until it's being tracked + backbone_out = {"backbone_fpn": None, "vision_pos_enc": None} + backbone_out = self.prepare_prompt_inputs(backbone_out, input) + previous_stages_out = self.forward_tracking(backbone_out, input) + + return previous_stages_out + + def _prepare_backbone_features_per_frame(self, img_batch, img_ids): + """Compute the image backbone features on the fly for the given img_ids.""" + # Only forward backbone on unique image ids to avoid repetitive computation + # (if `img_ids` has only one element, it's already unique so we skip this step). + if img_ids.numel() > 1: + unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True) + else: + unique_img_ids, inv_ids = img_ids, None + + # Compute the image features on those unique image ids + image = img_batch[unique_img_ids] + backbone_out = self.forward_image(image) + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + # Inverse-map image features for `unique_img_ids` to the final image features + # for the original input `img_ids`. + if inv_ids is not None: + image = image[inv_ids] + vision_feats = [x[:, inv_ids] for x in vision_feats] + vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds] + + return image, vision_feats, vision_pos_embeds, feat_sizes + + def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0): + """ + Prepare input mask, point or box prompts. Optionally, we allow tracking from + a custom `start_frame_idx` to the end of the video (for evaluation purposes). + """ + # Load the ground-truth masks on all frames (so that we can later + # sample correction points from them) + # gt_masks_per_frame = { + # stage_id: targets.segments.unsqueeze(1) # [B, 1, H_im, W_im] + # for stage_id, targets in enumerate(input.find_targets) + # } + gt_masks_per_frame = { + stage_id: masks.unsqueeze(1) # [B, 1, H_im, W_im] + for stage_id, masks in enumerate(input.masks) + } + # gt_masks_per_frame = input.masks.unsqueeze(2) # [T,B,1,H_im,W_im] keep everything in tensor form + backbone_out["gt_masks_per_frame"] = gt_masks_per_frame + num_frames = input.num_frames + backbone_out["num_frames"] = num_frames + + # Randomly decide whether to use point inputs or mask inputs + if self.training: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_train + prob_to_use_box_input = self.prob_to_use_box_input_for_train + num_frames_to_correct = self.num_frames_to_correct_for_train + rand_frames_to_correct = self.rand_frames_to_correct_for_train + num_init_cond_frames = self.num_init_cond_frames_for_train + rand_init_cond_frames = self.rand_init_cond_frames_for_train + else: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval + prob_to_use_box_input = self.prob_to_use_box_input_for_eval + num_frames_to_correct = self.num_frames_to_correct_for_eval + rand_frames_to_correct = self.rand_frames_to_correct_for_eval + num_init_cond_frames = self.num_init_cond_frames_for_eval + rand_init_cond_frames = self.rand_init_cond_frames_for_eval + if num_frames == 1: + # here we handle a special case for mixing video + SAM on image training, + # where we force using point input for the SAM task on static images + prob_to_use_pt_input = 1.0 + num_frames_to_correct = 1 + num_init_cond_frames = 1 + assert num_init_cond_frames >= 1 + # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0) + use_pt_input = self.rng.random() < prob_to_use_pt_input + if rand_init_cond_frames and num_init_cond_frames > 1: + # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames + num_init_cond_frames = self.rng.integers( + 1, num_init_cond_frames, endpoint=True + ) + if ( + use_pt_input + and rand_frames_to_correct + and num_frames_to_correct > num_init_cond_frames + ): + # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample + # correction clicks (only for the case of point input) + num_frames_to_correct = self.rng.integers( + num_init_cond_frames, num_frames_to_correct, endpoint=True + ) + backbone_out["use_pt_input"] = use_pt_input + + # Sample initial conditioning frames + if num_init_cond_frames == 1: + init_cond_frames = [start_frame_idx] # starting frame + else: + # starting frame + randomly selected remaining frames (without replacement) + init_cond_frames = [start_frame_idx] + self.rng.choice( + range(start_frame_idx + 1, num_frames), + num_init_cond_frames - 1, + replace=False, + ).tolist() + backbone_out["init_cond_frames"] = init_cond_frames + backbone_out["frames_not_in_init_cond"] = [ + t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames + ] + # Prepare mask or point inputs on initial conditioning frames + backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: } + backbone_out["point_inputs_per_frame"] = {} # {frame_idx: } + for t in init_cond_frames: + if not use_pt_input: + backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t] + else: + # During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input + use_box_input = self.rng.random() < prob_to_use_box_input + if use_box_input: + points, labels = sample_box_points( + gt_masks_per_frame[t], + ) + else: + # (here we only sample **one initial point** on initial conditioning frames from the + # ground-truth mask; we may sample more correction points on the fly) + points, labels = get_next_point( + gt_masks=gt_masks_per_frame[t], + pred_masks=None, + method=( + "uniform" if self.training else self.pt_sampling_for_eval + ), + ) + + point_inputs = {"point_coords": points, "point_labels": labels} + backbone_out["point_inputs_per_frame"][t] = point_inputs + + # Sample frames where we will add correction clicks on the fly + # based on the error between prediction and ground-truth masks + if not use_pt_input: + # no correction points will be sampled when using mask inputs + frames_to_add_correction_pt = [] + elif num_frames_to_correct == num_init_cond_frames: + frames_to_add_correction_pt = init_cond_frames + else: + assert num_frames_to_correct > num_init_cond_frames + # initial cond frame + randomly selected remaining frames (without replacement) + extra_num = num_frames_to_correct - num_init_cond_frames + frames_to_add_correction_pt = ( + init_cond_frames + + self.rng.choice( + backbone_out["frames_not_in_init_cond"], extra_num, replace=False + ).tolist() + ) + backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt + + return backbone_out + + def forward_tracking( + self, backbone_out, input: BatchedVideoDatapoint, return_dict=False + ): + """Forward video tracking on each frame (and sample correction clicks).""" + img_feats_already_computed = backbone_out["backbone_fpn"] is not None + if img_feats_already_computed: + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + output_dict = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + for stage_id in processing_order: + # Get the image features for the current frames + # img_ids = input.find_inputs[stage_id].img_ids + img_ids = input.flat_obj_to_img_idx[stage_id] + if img_feats_already_computed: + # Retrieve image features according to img_ids (if they are already computed). + current_vision_feats = [x[:, img_ids] for x in vision_feats] + current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds] + else: + # Otherwise, compute the image features on the fly for the given img_ids + # (this might be used for evaluation on long videos to avoid backbone OOM). + ( + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features_per_frame( + input.flat_img_batch, img_ids + ) + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None), + frames_to_add_correction_pt=frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + return all_frame_outputs + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks. + prev_sam_mask_logits=None, # The previously predicted SAM mask logits. + frames_to_add_correction_pt=None, + gt_masks=None, + ): + if frames_to_add_correction_pt is None: + frames_to_add_correction_pt = [] + current_out, sam_outputs, high_res_features, pix_feat = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["multistep_pred_masks"] = low_res_masks + current_out["multistep_pred_masks_high_res"] = high_res_masks + current_out["multistep_pred_multimasks"] = [low_res_multimasks] + current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] + current_out["multistep_pred_ious"] = [ious] + current_out["multistep_point_inputs"] = [point_inputs] + current_out["multistep_object_score_logits"] = [object_score_logits] + + # Optionally, sample correction points iteratively to correct the mask + if frame_idx in frames_to_add_correction_pt: + point_inputs, final_sam_outputs = self._iter_correct_pt_sampling( + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = final_sam_outputs + + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + return current_out + + def _iter_correct_pt_sampling( + self, + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat_with_mem, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ): + + assert gt_masks is not None + all_pred_masks = [low_res_masks] + all_pred_high_res_masks = [high_res_masks] + all_pred_multimasks = [low_res_multimasks] + all_pred_high_res_multimasks = [high_res_multimasks] + all_pred_ious = [ious] + all_point_inputs = [point_inputs] + all_object_score_logits = [object_score_logits] + for _ in range(self.num_correction_pt_per_frame): + # sample a new point from the error between prediction and ground-truth + # (with a small probability, directly sample from GT masks instead of errors) + if self.training and self.prob_to_sample_from_gt_for_train > 0: + sample_from_gt = ( + self.rng.random() < self.prob_to_sample_from_gt_for_train + ) + else: + sample_from_gt = False + # if `pred_for_new_pt` is None, only GT masks will be used for point sampling + pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0) + new_points, new_labels = get_next_point( + gt_masks=gt_masks, + pred_masks=pred_for_new_pt, + method="uniform" if self.training else self.pt_sampling_for_eval, + ) + point_inputs = concat_points(point_inputs, new_points, new_labels) + # Feed the mask logits of the previous SAM outputs in the next SAM decoder step. + # For tracking, this means that when the user adds a correction click, we also feed + # the tracking output mask logits along with the click as input to the SAM decoder. + mask_inputs = low_res_masks + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + if self.use_act_ckpt_iterative_pt_sampling and not multimask_output: + sam_outputs = torch.utils.checkpoint.checkpoint( + self._forward_sam_heads, + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + use_reentrant=False, + ) + else: + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + _, + object_score_logits, + ) = sam_outputs + all_pred_masks.append(low_res_masks) + all_pred_high_res_masks.append(high_res_masks) + all_pred_multimasks.append(low_res_multimasks) + all_pred_high_res_multimasks.append(high_res_multimasks) + all_pred_ious.append(ious) + all_point_inputs.append(point_inputs) + all_object_score_logits.append(object_score_logits) + + # Concatenate the masks along channel (to compute losses on all of them, + # using `MultiStepIteractiveMasks`) + current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1) + current_out["multistep_pred_masks_high_res"] = torch.cat( + all_pred_high_res_masks, dim=1 + ) + current_out["multistep_pred_multimasks"] = all_pred_multimasks + current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks + current_out["multistep_pred_ious"] = all_pred_ious + current_out["multistep_point_inputs"] = all_point_inputs + current_out["multistep_object_score_logits"] = all_object_score_logits + + return point_inputs, sam_outputs diff --git a/monst3r/third_party/sam2/training/optimizer.py b/monst3r/third_party/sam2/training/optimizer.py new file mode 100644 index 0000000..ae15966 --- /dev/null +++ b/monst3r/third_party/sam2/training/optimizer.py @@ -0,0 +1,502 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import fnmatch +import inspect +import itertools +import logging +import types +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) + +import hydra + +import torch +import torch.nn as nn +from omegaconf import DictConfig +from torch import Tensor + + +class Optimizer: + def __init__(self, optimizer, schedulers=None) -> None: + self.optimizer = optimizer + self.schedulers = schedulers + self._validate_optimizer_schedulers() + self.step_schedulers(0.0, 0) + + def _validate_optimizer_schedulers(self): + if self.schedulers is None: + return + for _, set_of_schedulers in enumerate(self.schedulers): + for option, _ in set_of_schedulers.items(): + assert option in self.optimizer.defaults, ( + "Optimizer option " + f"{option} not found in {self.optimizer}. Valid options are " + f"{self.optimizer.defaults.keys()}" + ) + + def step_schedulers(self, where: float, step: int) -> None: + if self.schedulers is None: + return + for i, param_group in enumerate(self.optimizer.param_groups): + for option, scheduler in self.schedulers[i].items(): + if "step" in inspect.signature(scheduler.__call__).parameters: + new_value = scheduler(step=step, where=where) + elif ( + hasattr(scheduler, "scheduler") + and "step" + in inspect.signature(scheduler.scheduler.__call__).parameters + ): + # To handle ValueScaler wrappers + new_value = scheduler(step=step, where=where) + else: + new_value = scheduler(where) + param_group[option] = new_value + + def step(self, where, step, closure=None): + self.step_schedulers(where, step) + return self.optimizer.step(closure) + + def zero_grad(self, *args, **kwargs): + return self.optimizer.zero_grad(*args, **kwargs) + + +def set_default_parameters( + scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str] +) -> None: + """Set up the "default" scheduler with the right parameters. + + Args: + scheduler_cgfs: A list of scheduler configs, where each scheduler also + specifies which parameters it applies to, based on the names of parameters + or the class of the modules. At most one scheduler is allowed to skip this + specification, which is used as a "default" specification for any remaining + parameters. + all_parameter_names: Names of all the parameters to consider. + """ + constraints = [ + scheduler_cfg.parameter_names + for scheduler_cfg in scheduler_cfgs + if scheduler_cfg.parameter_names is not None + ] + if len(constraints) == 0: + default_params = set(all_parameter_names) + else: + default_params = all_parameter_names - set.union(*constraints) + default_count = 0 + for scheduler_cfg in scheduler_cfgs: + if scheduler_cfg.parameter_names is None: + scheduler_cfg.parameter_names = default_params + default_count += 1 + assert default_count <= 1, "Only one scheduler per option can be default" + if default_count == 0: + # No default scheduler specified, add a default, but without any scheduler + # for that option + scheduler_cfgs.append({"parameter_names": default_params}) + + +def name_constraints_to_parameters( + param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor] +) -> List[torch.nn.Parameter]: + """Return parameters which match the intersection of parameter constraints. + + Note that this returns the parameters themselves, not their names. + + Args: + param_constraints: A list, with each element being a set of allowed parameters. + named_parameters: Mapping from a parameter name to the parameter itself. + + Returns: + A list containing the parameters which overlap with _each_ constraint set from + param_constraints. + """ + matching_names = set.intersection(*param_constraints) + return [value for name, value in named_parameters.items() if name in matching_names] + + +def map_scheduler_cfgs_to_param_groups( + all_scheduler_cfgs: Iterable[List[Dict]], + named_parameters: Dict[str, Tensor], +) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]: + """Produce parameter groups corresponding to all the scheduler configs. + + Takes all the scheduler configs, each of which applies to a specific optimizer + option (like "lr" or "weight_decay") and has a set of parameter names which it + applies to, and produces a final set of param groups where each param group + covers all the options which apply to a particular set of parameters. + + Args: + all_scheduler_cfgs: All the scheduler configs covering every option. + named_parameters: Mapping from a parameter name to the parameter itself. + Returns: + Tuple of lists of schedulers and param_groups, where schedulers[i] + applies to param_groups[i]. + """ + + scheduler_cfgs_per_param_group = itertools.product(*all_scheduler_cfgs) + schedulers = [] + param_groups = [] + for scheduler_cfgs in scheduler_cfgs_per_param_group: + param_constraints = [ + scheduler_cfg["parameter_names"] for scheduler_cfg in scheduler_cfgs + ] + matching_parameters = name_constraints_to_parameters( + param_constraints, named_parameters + ) + if len(matching_parameters) == 0: # If no overlap of parameters, skip + continue + schedulers_for_group = { + scheduler_cfg["option"]: scheduler_cfg["scheduler"] + for scheduler_cfg in scheduler_cfgs + if "option" in scheduler_cfg + } + schedulers.append(schedulers_for_group) + param_groups.append({"params": matching_parameters}) + return schedulers, param_groups + + +def validate_param_group_params(param_groups: List[Dict], model: nn.Module): + """Check that the param groups are non-overlapping and cover all the parameters. + + Args: + param_groups: List of all param groups + model: Model to validate against. The check ensures that all the model + parameters are part of param_groups + """ + for pg in param_groups: + # no param should be repeated within a group + assert len(pg["params"]) == len(set(pg["params"])) + parameters = [set(param_group["params"]) for param_group in param_groups] + model_parameters = {parameter for _, parameter in model.named_parameters()} + for p1, p2 in itertools.permutations(parameters, 2): + assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint" + assert set.union(*parameters) == model_parameters, ( + "Scheduler generated param_groups must include all parameters of the model." + f" Found {len(set.union(*parameters))} params whereas model has" + f" {len(model_parameters)} params" + ) + + +def unix_module_cls_pattern_to_parameter_names( + filter_module_cls_names: List[str], + module_cls_to_param_names: Dict[Type, str], +) -> Union[None, Set[str]]: + """Returns param names which pass the filters specified in filter_module_cls_names. + + Args: + filter_module_cls_names: A list of filter strings containing class names, like + ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"] + module_cls_to_param_names: Mapping from module classes to the parameter names + they contain. See `get_module_cls_to_param_names`. + """ + if filter_module_cls_names is None: + return set() + allowed_parameter_names = [] + for module_cls_name in filter_module_cls_names: + module_cls = hydra.utils.get_class(module_cls_name) + if module_cls not in module_cls_to_param_names: + raise AssertionError( + f"module_cls_name {module_cls_name} does not " + "match any classes in the model" + ) + matching_parameters = module_cls_to_param_names[module_cls] + assert ( + len(matching_parameters) > 0 + ), f"module_cls_name {module_cls_name} does not contain any parameters in the model" + logging.info( + f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} " + ) + allowed_parameter_names.append(matching_parameters) + return set.union(*allowed_parameter_names) + + +def unix_param_pattern_to_parameter_names( + filter_param_names: Optional[List[str]], + parameter_names: Dict[str, torch.Tensor], +) -> Union[None, Set[str]]: + """Returns param names which pass the filters specified in filter_param_names. + + Args: + filter_param_names: A list of unix-style filter strings with optional + wildcards, like ["block.2.*", "block.2.linear.weight"] + module_cls_to_param_names: Mapping from module classes to the parameter names + they contain. See `get_module_cls_to_param_names`. + """ + + if filter_param_names is None: + return set() + allowed_parameter_names = [] + for param_name in filter_param_names: + matching_parameters = set(fnmatch.filter(parameter_names, param_name)) + assert ( + len(matching_parameters) >= 1 + ), f"param_name {param_name} does not match any parameters in the model" + logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}") + allowed_parameter_names.append(matching_parameters) + return set.union(*allowed_parameter_names) + + +def _unix_pattern_to_parameter_names( + scheduler_cfg: DictConfig, + parameter_names: Set[str], + module_cls_to_param_names: Dict[Type, str], +) -> Union[None, Set[str]]: + """Returns param names which pass the filters specified in scheduler_cfg. + + Args: + scheduler_cfg: The config for the scheduler + parameter_names: The set of all parameter names which will be filtered + """ + if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg: + return None + return unix_param_pattern_to_parameter_names( + scheduler_cfg.get("param_names"), parameter_names + ).union( + unix_module_cls_pattern_to_parameter_names( + scheduler_cfg.get("module_cls_names"), module_cls_to_param_names + ) + ) + + +def get_module_cls_to_param_names( + model: nn.Module, param_allowlist: Set[str] = None +) -> Dict[Type, str]: + """Produce a mapping from all the modules classes to the names of parames they own. + + Only counts a parameter as part of the immediate parent module, i.e. recursive + parents do not count. + + Args: + model: Model to iterate over + param_allowlist: If specified, only these param names will be processed + """ + + module_cls_to_params = {} + for module_name, module in model.named_modules(): + module_cls = type(module) + module_cls_to_params.setdefault(module_cls, set()) + for param_name, _ in module.named_parameters(recurse=False): + full_param_name = get_full_parameter_name(module_name, param_name) + if param_allowlist is None or full_param_name in param_allowlist: + module_cls_to_params[module_cls].add(full_param_name) + return module_cls_to_params + + +def construct_optimizer( + model: torch.nn.Module, + optimizer_conf: Any, + options_conf: Mapping[str, List] = None, + param_group_modifiers_conf: List[Callable] = None, + param_allowlist: Optional[Set[str]] = None, + validate_param_groups=True, +) -> Optimizer: + """ + Constructs a stochastic gradient descent or ADAM (or ADAMw) optimizer + with momentum. i.e, constructs a torch.optim.Optimizer with zero-weight decay + Batchnorm and/or no-update 1-D parameters support, based on the config. + + Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling + (LARS): https://arxiv.org/abs/1708.03888 + + Args: + model: model to perform stochastic gradient descent + optimization or ADAM optimization. + optimizer_conf: Hydra config consisting a partial torch optimizer like SGD or + ADAM, still missing the params argument which this function provides to + produce the final optimizer + param_group_modifiers_conf: Optional user specified functions which can modify + the final scheduler configs before the optimizer's param groups are built + param_allowlist: The parameters to optimize. Parameters which are not part of + this allowlist will be skipped. + validate_param_groups: If enabled, valides that the produced param_groups don't + overlap and cover all the model parameters. + """ + if param_allowlist is None: + param_allowlist = {name for name, _ in model.named_parameters()} + + named_parameters = { + name: param + for name, param in model.named_parameters() + if name in param_allowlist + } + + if not options_conf: + optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values()) + return Optimizer(optimizer) + + all_parameter_names = { + name for name, _ in model.named_parameters() if name in param_allowlist + } + module_cls_to_all_param_names = get_module_cls_to_param_names( + model, param_allowlist + ) + + scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf) + all_scheduler_cfgs = [] + for option, scheduler_cfgs in scheduler_cfgs_per_option.items(): + for config in scheduler_cfgs: + config.option = option + config.parameter_names = _unix_pattern_to_parameter_names( + config, all_parameter_names, module_cls_to_all_param_names + ) + set_default_parameters(scheduler_cfgs, all_parameter_names) + all_scheduler_cfgs.append(scheduler_cfgs) + + if param_group_modifiers_conf: + for custom_param_modifier in param_group_modifiers_conf: + custom_param_modifier = hydra.utils.instantiate(custom_param_modifier) + all_scheduler_cfgs = custom_param_modifier( + scheduler_cfgs=all_scheduler_cfgs, model=model + ) + schedulers, param_groups = map_scheduler_cfgs_to_param_groups( + all_scheduler_cfgs, named_parameters + ) + if validate_param_groups: + validate_param_group_params(param_groups, model) + optimizer = hydra.utils.instantiate(optimizer_conf, param_groups) + return Optimizer(optimizer, schedulers) + + +def get_full_parameter_name(module_name, param_name): + if module_name == "": + return param_name + return f"{module_name}.{param_name}" + + +class GradientClipper: + """ + Gradient clipping utils that works for DDP + """ + + def __init__(self, max_norm: float = 1.0, norm_type: int = 2): + assert isinstance(max_norm, (int, float)) or max_norm is None + self.max_norm = max_norm if max_norm is None else float(max_norm) + self.norm_type = norm_type + + def __call__(self, model: nn.Module): + if self.max_norm is None: + return # no-op + + nn.utils.clip_grad_norm_( + model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type + ) + + +class ValueScaler: + def __init__(self, scheduler, mult_val: float): + self.scheduler = scheduler + self.mult_val = mult_val + + def __call__(self, *args, **kwargs): + val = self.scheduler(*args, **kwargs) + return val * self.mult_val + + +def rgetattr(obj, rattrs: str = None): + """ + Like getattr(), but supports dotted notation for nested objects. + rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2 + """ + if rattrs is None: + return obj + attrs = rattrs.split(".") + for attr in attrs: + obj = getattr(obj, attr) + return obj + + +def layer_decay_param_modifier( + scheduler_cfgs: List[List[Dict]], + model, + layer_decay_value: float, + layer_decay_min: Optional[float] = None, + apply_to: Optional[str] = None, + overrides: List[Dict] = (), +) -> List[List[Dict]]: + """ + Args + - scheduler_cfgs: a list of omegaconf.ListConfigs. + Each element in the list is a omegaconfg.DictConfig with the following structure + { + "scheduler": + "option": possible options are "lr", "weight_decay" etc. + "parameter_names": Set of str indicating param names that this scheduler applies to + } + - model: a model that implements a method `get_layer_id` that maps layer_name to an integer and + and a method get_num_layers. + Alternatively, use apply_to argument to select a specific component of the model. + - layer_decay_value: float + - layer_decay_min: min val for layer decay + - apply_to: optional arg to select which component of the model to apply the the layer decay modifier to + - overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict, has keys "pattern", "value". + Returns + - scheduler_configs: same structure as the input, elements can be modified + """ + model = rgetattr(model, apply_to) + num_layers = model.get_num_layers() + 1 + layer_decays = [ + layer_decay_value ** (num_layers - i) for i in range(num_layers + 1) + ] + if layer_decay_min is not None: + layer_decays = [max(val, layer_decay_min) for val in layer_decays] + final_scheduler_cfgs = [] + # scheduler_cfgs is a list of lists + for scheduler_cfg_group in scheduler_cfgs: + curr_cfg_group = [] + # scheduler_cfg_group is a list of dictionaries + for scheduler_cfg in scheduler_cfg_group: + if scheduler_cfg["option"] != "lr": + curr_cfg_group.append(scheduler_cfg) + continue + # Need sorted so that the list of parameter names is deterministic and consistent + # across re-runs of this job. Else it was causing issues with loading the optimizer + # state during a job restart (D38591759) + parameter_names = sorted(scheduler_cfg["parameter_names"]) + + # Only want one cfg group per layer + layer_cfg_groups = {} + for param_name in parameter_names: + layer_id = num_layers + this_scale = layer_decays[layer_id] + if param_name.startswith(apply_to): + layer_id = model.get_layer_id(param_name) + this_scale = layer_decays[layer_id] + # Overrides + for override in overrides: + if fnmatch.fnmatchcase(param_name, override["pattern"]): + this_scale = float(override["value"]) + layer_id = override["pattern"] + break + + if layer_id not in layer_cfg_groups: + curr_param = { + "option": scheduler_cfg["option"], + "scheduler": ValueScaler( + scheduler_cfg["scheduler"], this_scale + ), + "parameter_names": {param_name}, + } + else: + curr_param = layer_cfg_groups[layer_id] + curr_param["parameter_names"].add(param_name) + layer_cfg_groups[layer_id] = curr_param + + for layer_cfg in layer_cfg_groups.values(): + curr_cfg_group.append(layer_cfg) + + final_scheduler_cfgs.append(curr_cfg_group) + return final_scheduler_cfgs diff --git a/monst3r/third_party/sam2/training/scripts/sav_frame_extraction_submitit.py b/monst3r/third_party/sam2/training/scripts/sav_frame_extraction_submitit.py new file mode 100644 index 0000000..9d5ed2f --- /dev/null +++ b/monst3r/third_party/sam2/training/scripts/sav_frame_extraction_submitit.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +import argparse +import os +from pathlib import Path + +import cv2 + +import numpy as np +import submitit +import tqdm + + +def get_args_parser(): + parser = argparse.ArgumentParser( + description="[SA-V Preprocessing] Extracting JPEG frames", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # ------------ + # DATA + # ------------ + data_parser = parser.add_argument_group( + title="SA-V dataset data root", + description="What data to load and how to process it.", + ) + data_parser.add_argument( + "--sav-vid-dir", + type=str, + required=True, + help=("Where to find the SAV videos"), + ) + data_parser.add_argument( + "--sav-frame-sample-rate", + type=int, + default=4, + help="Rate at which to sub-sample frames", + ) + + # ------------ + # LAUNCH + # ------------ + launch_parser = parser.add_argument_group( + title="Cluster launch settings", + description="Number of jobs and retry settings.", + ) + launch_parser.add_argument( + "--n-jobs", + type=int, + required=True, + help="Shard the run over this many jobs.", + ) + launch_parser.add_argument( + "--timeout", type=int, required=True, help="SLURM timeout parameter in minutes." + ) + launch_parser.add_argument( + "--partition", type=str, required=True, help="Partition to launch on." + ) + launch_parser.add_argument( + "--account", type=str, required=True, help="Partition to launch on." + ) + launch_parser.add_argument("--qos", type=str, required=True, help="QOS.") + + # ------------ + # OUTPUT + # ------------ + output_parser = parser.add_argument_group( + title="Setting for results output", description="Where and how to save results." + ) + output_parser.add_argument( + "--output-dir", + type=str, + required=True, + help=("Where to dump the extracted jpeg frames"), + ) + output_parser.add_argument( + "--slurm-output-root-dir", + type=str, + required=True, + help=("Where to save slurm outputs"), + ) + return parser + + +def decode_video(video_path: str): + assert os.path.exists(video_path) + video = cv2.VideoCapture(video_path) + video_frames = [] + while video.isOpened(): + ret, frame = video.read() + if ret: + video_frames.append(frame) + else: + break + return video_frames + + +def extract_frames(video_path, sample_rate): + frames = decode_video(video_path) + return frames[::sample_rate] + + +def submitit_launch(video_paths, sample_rate, save_root): + for path in tqdm.tqdm(video_paths): + frames = extract_frames(path, sample_rate) + output_folder = os.path.join(save_root, Path(path).stem) + if not os.path.exists(output_folder): + os.makedirs(output_folder) + for fid, frame in enumerate(frames): + frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg") + cv2.imwrite(frame_path, frame) + print(f"Saved output to {save_root}") + + +if __name__ == "__main__": + parser = get_args_parser() + args = parser.parse_args() + + sav_vid_dir = args.sav_vid_dir + save_root = args.output_dir + sample_rate = args.sav_frame_sample_rate + + # List all SA-V videos + mp4_files = sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")]) + mp4_files = np.array(mp4_files) + chunked_mp4_files = [x.tolist() for x in np.array_split(mp4_files, args.n_jobs)] + + print(f"Processing videos in: {sav_vid_dir}") + print(f"Processing {len(mp4_files)} files") + print(f"Beginning processing in {args.n_jobs} processes") + + # Submitit params + jobs_dir = os.path.join(args.slurm_output_root_dir, "%j") + cpus_per_task = 4 + executor = submitit.AutoExecutor(folder=jobs_dir) + executor.update_parameters( + timeout_min=args.timeout, + gpus_per_node=0, + tasks_per_node=1, + slurm_array_parallelism=args.n_jobs, + cpus_per_task=cpus_per_task, + slurm_partition=args.partition, + slurm_account=args.account, + slurm_qos=args.qos, + ) + executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"]) + + # Launch + jobs = [] + with executor.batch(): + for _, mp4_chunk in tqdm.tqdm(enumerate(chunked_mp4_files)): + job = executor.submit( + submitit_launch, + video_paths=mp4_chunk, + sample_rate=sample_rate, + save_root=save_root, + ) + jobs.append(job) + + for j in jobs: + print(f"Slurm JobID: {j.job_id}") + print(f"Saving outputs to {save_root}") + print(f"Slurm outputs at {args.slurm_output_root_dir}") diff --git a/monst3r/third_party/sam2/training/train.py b/monst3r/third_party/sam2/training/train.py new file mode 100644 index 0000000..db06123 --- /dev/null +++ b/monst3r/third_party/sam2/training/train.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import sys +import traceback +from argparse import ArgumentParser + +import submitit +import torch + +from hydra import compose, initialize_config_module +from hydra.utils import instantiate + +from iopath.common.file_io import g_pathmgr +from omegaconf import OmegaConf + +from training.utils.train_utils import makedir, register_omegaconf_resolvers + +os.environ["HYDRA_FULL_ERROR"] = "1" + + +def single_proc_run(local_rank, main_port, cfg, world_size): + """Single GPU process""" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(main_port) + os.environ["RANK"] = str(local_rank) + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + try: + register_omegaconf_resolvers() + except Exception as e: + logging.info(e) + + trainer = instantiate(cfg.trainer, _recursive_=False) + trainer.run() + + +def single_node_runner(cfg, main_port: int): + assert cfg.launcher.num_nodes == 1 + num_proc = cfg.launcher.gpus_per_node + torch.multiprocessing.set_start_method( + "spawn" + ) # CUDA runtime does not support `fork` + if num_proc == 1: + # directly call single_proc so we can easily set breakpoints + # mp.spawn does not let us set breakpoints + single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc) + else: + mp_runner = torch.multiprocessing.start_processes + args = (main_port, cfg, num_proc) + # Note: using "fork" below, "spawn" causes time and error regressions. Using + # spawn changes the default multiprocessing context to spawn, which doesn't + # interact well with the dataloaders (likely due to the use of OpenCV). + mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn") + + +def format_exception(e: Exception, limit=20): + traceback_str = "".join(traceback.format_tb(e.__traceback__, limit=limit)) + return f"{type(e).__name__}: {e}\nTraceback:\n{traceback_str}" + + +class SubmititRunner(submitit.helpers.Checkpointable): + """A callable which is passed to submitit to launch the jobs.""" + + def __init__(self, port, cfg): + self.cfg = cfg + self.port = port + self.has_setup = False + + def run_trainer(self): + job_env = submitit.JobEnvironment() + # Need to add this again so the hydra.job.set_env PYTHONPATH + # is also set when launching jobs. + add_pythonpath_to_sys_path() + os.environ["MASTER_ADDR"] = job_env.hostnames[0] + os.environ["MASTER_PORT"] = str(self.port) + os.environ["RANK"] = str(job_env.global_rank) + os.environ["LOCAL_RANK"] = str(job_env.local_rank) + os.environ["WORLD_SIZE"] = str(job_env.num_tasks) + + register_omegaconf_resolvers() + cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False) + cfg_resolved = OmegaConf.create(cfg_resolved) + + trainer = instantiate(cfg_resolved.trainer, _recursive_=False) + trainer.run() + + def __call__(self): + job_env = submitit.JobEnvironment() + self.setup_job_info(job_env.job_id, job_env.global_rank) + try: + self.run_trainer() + except Exception as e: + # Log the exception. Then raise it again (as what SubmititRunner currently does). + message = format_exception(e) + logging.error(message) + raise e + + def setup_job_info(self, job_id, rank): + """Set up slurm job info""" + self.job_info = { + "job_id": job_id, + "rank": rank, + "cluster": self.cfg.get("cluster", None), + "experiment_log_dir": self.cfg.launcher.experiment_log_dir, + } + + self.has_setup = True + + +def add_pythonpath_to_sys_path(): + if "PYTHONPATH" not in os.environ or not os.environ["PYTHONPATH"]: + return + sys.path = os.environ["PYTHONPATH"].split(":") + sys.path + + +def main(args) -> None: + cfg = compose(config_name=args.config) + if cfg.launcher.experiment_log_dir is None: + cfg.launcher.experiment_log_dir = os.path.join( + os.getcwd(), "sam2_logs", args.config + ) + print("###################### Train App Config ####################") + print(OmegaConf.to_yaml(cfg)) + print("############################################################") + + add_pythonpath_to_sys_path() + makedir(cfg.launcher.experiment_log_dir) + with g_pathmgr.open( + os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w" + ) as f: + f.write(OmegaConf.to_yaml(cfg)) + + cfg_resolved = OmegaConf.to_container(cfg, resolve=False) + cfg_resolved = OmegaConf.create(cfg_resolved) + + with g_pathmgr.open( + os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w" + ) as f: + f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True)) + + submitit_conf = cfg.get("submitit", None) + assert submitit_conf is not None, "Missing submitit config" + + submitit_dir = cfg.launcher.experiment_log_dir + submitit_dir = os.path.join(submitit_dir, "submitit_logs") + # Priotrize cmd line args + cfg.launcher.gpus_per_node = ( + args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node + ) + cfg.launcher.num_nodes = ( + args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes + ) + submitit_conf.use_cluster = ( + args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster + ) + if submitit_conf.use_cluster: + executor = submitit.AutoExecutor(folder=submitit_dir) + submitit_conf.partition = ( + args.partition + if args.partition is not None + else submitit_conf.get("partition", None) + ) + submitit_conf.account = ( + args.account + if args.account is not None + else submitit_conf.get("account", None) + ) + submitit_conf.qos = ( + args.qos if args.qos is not None else submitit_conf.get("qos", None) + ) + job_kwargs = { + "timeout_min": 60 * submitit_conf.timeout_hour, + "name": ( + submitit_conf.name if hasattr(submitit_conf, "name") else args.config + ), + "slurm_partition": submitit_conf.partition, + "gpus_per_node": cfg.launcher.gpus_per_node, + "tasks_per_node": cfg.launcher.gpus_per_node, # one task per GPU + "cpus_per_task": submitit_conf.cpus_per_task, + "nodes": cfg.launcher.num_nodes, + "slurm_additional_parameters": { + "exclude": " ".join(submitit_conf.get("exclude_nodes", [])), + }, + } + if "include_nodes" in submitit_conf: + assert ( + len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes + ), "Not enough nodes" + job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join( + submitit_conf["include_nodes"] + ) + if submitit_conf.account is not None: + job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account + if submitit_conf.qos is not None: + job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos + + if submitit_conf.get("mem_gb", None) is not None: + job_kwargs["mem_gb"] = submitit_conf.mem_gb + elif submitit_conf.get("mem", None) is not None: + job_kwargs["slurm_mem"] = submitit_conf.mem + + if submitit_conf.get("constraints", None) is not None: + job_kwargs["slurm_constraint"] = submitit_conf.constraints + + if submitit_conf.get("comment", None) is not None: + job_kwargs["slurm_comment"] = submitit_conf.comment + + # Supports only cpu-bind option within srun_args. New options can be added here + if submitit_conf.get("srun_args", None) is not None: + job_kwargs["slurm_srun_args"] = [] + if submitit_conf.srun_args.get("cpu_bind", None) is not None: + job_kwargs["slurm_srun_args"].extend( + ["--cpu-bind", submitit_conf.srun_args.cpu_bind] + ) + + print("###################### SLURM Config ####################") + print(job_kwargs) + print("##########################################") + executor.update_parameters(**job_kwargs) + + main_port = random.randint( + submitit_conf.port_range[0], submitit_conf.port_range[1] + ) + runner = SubmititRunner(main_port, cfg) + job = executor.submit(runner) + print(f"Submitit Job ID: {job.job_id}") + runner.setup_job_info(job.job_id, rank=0) + else: + cfg.launcher.num_nodes = 1 + main_port = random.randint( + submitit_conf.port_range[0], submitit_conf.port_range[1] + ) + single_node_runner(cfg, main_port) + + +if __name__ == "__main__": + + initialize_config_module("sam2", version_base="1.2") + parser = ArgumentParser() + parser.add_argument( + "-c", + "--config", + required=True, + type=str, + help="path to config file (e.g. configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml)", + ) + parser.add_argument( + "--use-cluster", + type=int, + default=None, + help="whether to launch on a cluster, 0: run locally, 1: run on a cluster", + ) + parser.add_argument("--partition", type=str, default=None, help="SLURM partition") + parser.add_argument("--account", type=str, default=None, help="SLURM account") + parser.add_argument("--qos", type=str, default=None, help="SLURM qos") + parser.add_argument( + "--num-gpus", type=int, default=None, help="number of GPUS per node" + ) + parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes") + args = parser.parse_args() + args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None + register_omegaconf_resolvers() + main(args) diff --git a/monst3r/third_party/sam2/training/trainer.py b/monst3r/third_party/sam2/training/trainer.py new file mode 100644 index 0000000..2b7c27b --- /dev/null +++ b/monst3r/third_party/sam2/training/trainer.py @@ -0,0 +1,1113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import gc +import json +import logging +import math +import os +import time +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Mapping, Optional + +import numpy as np + +import torch +import torch.distributed as dist +import torch.nn as nn +from hydra.utils import instantiate +from iopath.common.file_io import g_pathmgr + +from training.optimizer import construct_optimizer + +from training.utils.checkpoint_utils import ( + assert_skipped_parameters_are_frozen, + exclude_params_matching_unix_pattern, + load_state_dict_into_model, + with_check_parameter_frozen, +) +from training.utils.data_utils import BatchedVideoDatapoint +from training.utils.distributed import all_reduce_max, barrier, get_rank + +from training.utils.logger import Logger, setup_logging + +from training.utils.train_utils import ( + AverageMeter, + collect_dict_keys, + DurationMeter, + get_amp_type, + get_machine_local_and_dist_rank, + get_resume_checkpoint, + human_readable_time, + is_dist_avail_and_initialized, + log_env_variables, + makedir, + MemMeter, + Phase, + ProgressMeter, + set_seeds, + setup_distributed_backend, +) + + +CORE_LOSS_KEY = "core_loss" + + +def unwrap_ddp_if_wrapped(model): + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + return model.module + return model + + +@dataclass +class OptimAMPConf: + enabled: bool = False + amp_dtype: str = "float16" + + +@dataclass +class OptimConf: + optimizer: torch.optim.Optimizer = None + options: Optional[Dict[str, Any]] = None + param_group_modifiers: Optional[List] = None + amp: Optional[Dict[str, Any]] = None + gradient_clip: Any = None + gradient_logger: Any = None + + def __post_init__(self): + # amp + if not isinstance(self.amp, OptimAMPConf): + if self.amp is None: + self.amp = {} + assert isinstance(self.amp, Mapping) + self.amp = OptimAMPConf(**self.amp) + + +@dataclass +class DistributedConf: + backend: Optional[str] = None # inferred from accelerator type + comms_dtype: Optional[str] = None + find_unused_parameters: bool = False + timeout_mins: int = 30 + + +@dataclass +class CudaConf: + cudnn_deterministic: bool = False + cudnn_benchmark: bool = True + allow_tf32: bool = False + # if not None, `matmul_allow_tf32` key will override `allow_tf32` for matmul + matmul_allow_tf32: Optional[bool] = None + # if not None, `cudnn_allow_tf32` key will override `allow_tf32` for cudnn + cudnn_allow_tf32: Optional[bool] = None + + +@dataclass +class CheckpointConf: + save_dir: str + save_freq: int + save_list: List[int] = field(default_factory=list) + model_weight_initializer: Any = None + save_best_meters: List[str] = None + skip_saving_parameters: List[str] = field(default_factory=list) + initialize_after_preemption: Optional[bool] = None + # if not None, training will be resumed from this checkpoint + resume_from: Optional[str] = None + + def infer_missing(self): + if self.initialize_after_preemption is None: + with_skip_saving = len(self.skip_saving_parameters) > 0 + self.initialize_after_preemption = with_skip_saving + return self + + +@dataclass +class LoggingConf: + log_dir: str + log_freq: int # In iterations + tensorboard_writer: Any + log_level_primary: str = "INFO" + log_level_secondary: str = "ERROR" + log_scalar_frequency: int = 100 + log_visual_frequency: int = 100 + scalar_keys_to_log: Optional[Dict[str, Any]] = None + log_batch_stats: bool = False + + +class Trainer: + """ + Trainer supporting the DDP training strategies. + """ + + EPSILON = 1e-8 + + def __init__( + self, + *, # the order of these args can change at any time, so they are keyword-only + data: Dict[str, Any], + model: Dict[str, Any], + logging: Dict[str, Any], + checkpoint: Dict[str, Any], + max_epochs: int, + mode: str = "train", + accelerator: str = "cuda", + seed_value: int = 123, + val_epoch_freq: int = 1, + distributed: Dict[str, bool] = None, + cuda: Dict[str, bool] = None, + env_variables: Optional[Dict[str, Any]] = None, + optim: Optional[Dict[str, Any]] = None, + optim_overrides: Optional[List[Dict[str, Any]]] = None, + meters: Optional[Dict[str, Any]] = None, + loss: Optional[Dict[str, Any]] = None, + ): + + self._setup_env_variables(env_variables) + self._setup_timers() + + self.data_conf = data + self.model_conf = model + self.logging_conf = LoggingConf(**logging) + self.checkpoint_conf = CheckpointConf(**checkpoint).infer_missing() + self.max_epochs = max_epochs + self.mode = mode + self.val_epoch_freq = val_epoch_freq + self.optim_conf = OptimConf(**optim) if optim is not None else None + self.meters_conf = meters + self.loss_conf = loss + distributed = DistributedConf(**distributed or {}) + cuda = CudaConf(**cuda or {}) + self.where = 0.0 + + self._infer_distributed_backend_if_none(distributed, accelerator) + + self._setup_device(accelerator) + + self._setup_torch_dist_and_backend(cuda, distributed) + + makedir(self.logging_conf.log_dir) + setup_logging( + __name__, + output_dir=self.logging_conf.log_dir, + rank=self.rank, + log_level_primary=self.logging_conf.log_level_primary, + log_level_secondary=self.logging_conf.log_level_secondary, + ) + + set_seeds(seed_value, self.max_epochs, self.distributed_rank) + log_env_variables() + + assert ( + is_dist_avail_and_initialized() + ), "Torch distributed needs to be initialized before calling the trainer." + + self._setup_components() # Except Optimizer everything is setup here. + self._move_to_device() + self._construct_optimizers() + self._setup_dataloaders() + + self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f") + + if self.checkpoint_conf.resume_from is not None: + assert os.path.exists( + self.checkpoint_conf.resume_from + ), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!" + dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt") + if self.distributed_rank == 0 and not os.path.exists(dst): + # Copy the "resume_from" checkpoint to the checkpoint folder + # if there is not a checkpoint to resume from already there + makedir(self.checkpoint_conf.save_dir) + g_pathmgr.copy(self.checkpoint_conf.resume_from, dst) + barrier() + + self.load_checkpoint() + self._setup_ddp_distributed_training(distributed, accelerator) + barrier() + + def _setup_timers(self): + """ + Initializes counters for elapsed time and eta. + """ + self.start_time = time.time() + self.ckpt_time_elapsed = 0 + self.est_epoch_time = dict.fromkeys([Phase.TRAIN, Phase.VAL], 0) + + def _get_meters(self, phase_filters=None): + if self.meters is None: + return {} + meters = {} + for phase, phase_meters in self.meters.items(): + if phase_filters is not None and phase not in phase_filters: + continue + for key, key_meters in phase_meters.items(): + if key_meters is None: + continue + for name, meter in key_meters.items(): + meters[f"{phase}_{key}/{name}"] = meter + return meters + + def _infer_distributed_backend_if_none(self, distributed_conf, accelerator): + if distributed_conf.backend is None: + distributed_conf.backend = "nccl" if accelerator == "cuda" else "gloo" + + def _setup_env_variables(self, env_variables_conf) -> None: + if env_variables_conf is not None: + for variable_name, value in env_variables_conf.items(): + os.environ[variable_name] = value + + def _setup_torch_dist_and_backend(self, cuda_conf, distributed_conf) -> None: + if torch.cuda.is_available(): + torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic + torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark + torch.backends.cuda.matmul.allow_tf32 = ( + cuda_conf.matmul_allow_tf32 + if cuda_conf.matmul_allow_tf32 is not None + else cuda_conf.allow_tf32 + ) + torch.backends.cudnn.allow_tf32 = ( + cuda_conf.cudnn_allow_tf32 + if cuda_conf.cudnn_allow_tf32 is not None + else cuda_conf.allow_tf32 + ) + + self.rank = setup_distributed_backend( + distributed_conf.backend, distributed_conf.timeout_mins + ) + + def _setup_device(self, accelerator): + self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank() + if accelerator == "cuda": + self.device = torch.device("cuda", self.local_rank) + torch.cuda.set_device(self.local_rank) + elif accelerator == "cpu": + self.device = torch.device("cpu") + else: + raise ValueError(f"Unsupported accelerator: {accelerator}") + + def _setup_ddp_distributed_training(self, distributed_conf, accelerator): + + assert isinstance(self.model, torch.nn.Module) + + self.model = nn.parallel.DistributedDataParallel( + self.model, + device_ids=[self.local_rank] if accelerator == "cuda" else [], + find_unused_parameters=distributed_conf.find_unused_parameters, + ) + if distributed_conf.comms_dtype is not None: # noqa + from torch.distributed.algorithms import ddp_comm_hooks + + amp_type = get_amp_type(distributed_conf.comms_dtype) + if amp_type == torch.bfloat16: + hook = ddp_comm_hooks.default_hooks.bf16_compress_hook + logging.info("Enabling bfloat16 grad communication") + else: + hook = ddp_comm_hooks.default_hooks.fp16_compress_hook + logging.info("Enabling fp16 grad communication") + process_group = None + self.model.register_comm_hook(process_group, hook) + + def _move_to_device(self): + logging.info( + f"Moving components to device {self.device} and local rank {self.local_rank}." + ) + + self.model.to(self.device) + + logging.info( + f"Done moving components to device {self.device} and local rank {self.local_rank}." + ) + + def save_checkpoint(self, epoch, checkpoint_names=None): + checkpoint_folder = self.checkpoint_conf.save_dir + makedir(checkpoint_folder) + if checkpoint_names is None: + checkpoint_names = ["checkpoint"] + if ( + self.checkpoint_conf.save_freq > 0 + and (int(epoch) % self.checkpoint_conf.save_freq == 0) + ) or int(epoch) in self.checkpoint_conf.save_list: + checkpoint_names.append(f"checkpoint_{int(epoch)}") + + checkpoint_paths = [] + for ckpt_name in checkpoint_names: + checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt")) + + state_dict = unwrap_ddp_if_wrapped(self.model).state_dict() + state_dict = exclude_params_matching_unix_pattern( + patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict + ) + + checkpoint = { + "model": state_dict, + "optimizer": self.optim.optimizer.state_dict(), + "epoch": epoch, + "loss": self.loss.state_dict(), + "steps": self.steps, + "time_elapsed": self.time_elapsed_meter.val, + "best_meter_values": self.best_meter_values, + } + if self.optim_conf.amp.enabled: + checkpoint["scaler"] = self.scaler.state_dict() + + # DDP checkpoints are only saved on rank 0 (all workers are identical) + if self.distributed_rank != 0: + return + + for checkpoint_path in checkpoint_paths: + self._save_checkpoint(checkpoint, checkpoint_path) + + def _save_checkpoint(self, checkpoint, checkpoint_path): + """ + Save a checkpoint while guarding against the job being killed in the middle + of checkpoint saving (which corrupts the checkpoint file and ruins the + entire training since usually only the last checkpoint is kept per run). + + We first save the new checkpoint to a temp file (with a '.tmp' suffix), and + and move it to overwrite the old checkpoint_path. + """ + checkpoint_path_tmp = f"{checkpoint_path}.tmp" + with g_pathmgr.open(checkpoint_path_tmp, "wb") as f: + torch.save(checkpoint, f) + # after torch.save is completed, replace the old checkpoint with the new one + if g_pathmgr.exists(checkpoint_path): + # remove the old checkpoint_path file first (otherwise g_pathmgr.mv fails) + g_pathmgr.rm(checkpoint_path) + success = g_pathmgr.mv(checkpoint_path_tmp, checkpoint_path) + assert success + + def load_checkpoint(self): + ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir) + if ckpt_path is None: + self._init_model_state() + else: + if self.checkpoint_conf.initialize_after_preemption: + self._call_model_initializer() + self._load_resuming_checkpoint(ckpt_path) + + def _init_model_state(self): + # Checking that parameters that won't be saved are indeed frozen + # We do this check here before even saving the model to catch errors + # are early as possible and not at the end of the first epoch + assert_skipped_parameters_are_frozen( + patterns=self.checkpoint_conf.skip_saving_parameters, + model=self.model, + ) + + # Checking that parameters that won't be saved are initialized from + # within the model definition, unless `initialize_after_preemption` + # is explicitly set to `True`. If not, this is a bug, and after + # preemption, the `skip_saving_parameters` will have random values + allow_init_skip_parameters = self.checkpoint_conf.initialize_after_preemption + with with_check_parameter_frozen( + patterns=self.checkpoint_conf.skip_saving_parameters, + model=self.model, + disabled=allow_init_skip_parameters, + ): + self._call_model_initializer() + + def _call_model_initializer(self): + model_weight_initializer = instantiate( + self.checkpoint_conf.model_weight_initializer + ) + if model_weight_initializer is not None: + logging.info( + f"Loading pretrained checkpoint from {self.checkpoint_conf.model_weight_initializer}" + ) + self.model = model_weight_initializer(model=self.model) + + def _load_resuming_checkpoint(self, ckpt_path: str): + logging.info(f"Resuming training from {ckpt_path}") + + with g_pathmgr.open(ckpt_path, "rb") as f: + checkpoint = torch.load(f, map_location="cpu") + load_state_dict_into_model( + model=self.model, + state_dict=checkpoint["model"], + ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters, + ) + + self.optim.optimizer.load_state_dict(checkpoint["optimizer"]) + self.loss.load_state_dict(checkpoint["loss"], strict=True) + self.epoch = checkpoint["epoch"] + self.steps = checkpoint["steps"] + self.ckpt_time_elapsed = checkpoint.get("time_elapsed") + + if self.optim_conf.amp.enabled and "scaler" in checkpoint: + self.scaler.load_state_dict(checkpoint["scaler"]) + + self.best_meter_values = checkpoint.get("best_meter_values", {}) + + if "train_dataset" in checkpoint and self.train_dataset is not None: + self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"]) + + def is_intermediate_val_epoch(self, epoch): + return epoch % self.val_epoch_freq == 0 and epoch < self.max_epochs - 1 + + def _step( + self, + batch: BatchedVideoDatapoint, + model: nn.Module, + phase: str, + ): + + outputs = model(batch) + targets = batch.masks + batch_size = len(batch.img_batch) + + key = batch.dict_key # key for dataset + loss = self.loss[key](outputs, targets) + loss_str = f"Losses/{phase}_{key}_loss" + + loss_log_str = os.path.join("Step_Losses", loss_str) + + # loss contains multiple sub-components we wish to log + step_losses = {} + if isinstance(loss, dict): + step_losses.update( + {f"Losses/{phase}_{key}_{k}": v for k, v in loss.items()} + ) + loss = self._log_loss_detailed_and_return_core_loss( + loss, loss_log_str, self.steps[phase] + ) + + if self.steps[phase] % self.logging_conf.log_scalar_frequency == 0: + self.logger.log( + loss_log_str, + loss, + self.steps[phase], + ) + + self.steps[phase] += 1 + + ret_tuple = {loss_str: loss}, batch_size, step_losses + + if phase in self.meters and key in self.meters[phase]: + meters_dict = self.meters[phase][key] + if meters_dict is not None: + for _, meter in meters_dict.items(): + meter.update( + find_stages=outputs, + find_metadatas=batch.metadata, + ) + + return ret_tuple + + def run(self): + assert self.mode in ["train", "train_only", "val"] + if self.mode == "train": + if self.epoch > 0: + logging.info(f"Resuming training from epoch: {self.epoch}") + # resuming from a checkpoint + if self.is_intermediate_val_epoch(self.epoch - 1): + logging.info("Running previous val epoch") + self.epoch -= 1 + self.run_val() + self.epoch += 1 + self.run_train() + self.run_val() + elif self.mode == "val": + self.run_val() + elif self.mode == "train_only": + self.run_train() + + def _setup_dataloaders(self): + self.train_dataset = None + self.val_dataset = None + + if self.mode in ["train", "val"]: + self.val_dataset = instantiate(self.data_conf.get(Phase.VAL, None)) + + if self.mode in ["train", "train_only"]: + self.train_dataset = instantiate(self.data_conf.train) + + def run_train(self): + + while self.epoch < self.max_epochs: + dataloader = self.train_dataset.get_loader(epoch=int(self.epoch)) + barrier() + outs = self.train_epoch(dataloader) + self.logger.log_dict(outs, self.epoch) # Logged only on rank 0 + + # log train to text file. + if self.distributed_rank == 0: + with g_pathmgr.open( + os.path.join(self.logging_conf.log_dir, "train_stats.json"), + "a", + ) as f: + f.write(json.dumps(outs) + "\n") + + # Save checkpoint before validating + self.save_checkpoint(self.epoch + 1) + + del dataloader + gc.collect() + + # Run val, not running on last epoch since will run after the + # loop anyway + if self.is_intermediate_val_epoch(self.epoch): + self.run_val() + + if self.distributed_rank == 0: + self.best_meter_values.update(self._get_trainer_state("train")) + with g_pathmgr.open( + os.path.join(self.logging_conf.log_dir, "best_stats.json"), + "a", + ) as f: + f.write(json.dumps(self.best_meter_values) + "\n") + + self.epoch += 1 + # epoch was incremented in the loop but the val step runs out of the loop + self.epoch -= 1 + + def run_val(self): + if not self.val_dataset: + return + + dataloader = self.val_dataset.get_loader(epoch=int(self.epoch)) + outs = self.val_epoch(dataloader, phase=Phase.VAL) + del dataloader + gc.collect() + self.logger.log_dict(outs, self.epoch) # Logged only on rank 0 + + if self.distributed_rank == 0: + with g_pathmgr.open( + os.path.join(self.logging_conf.log_dir, "val_stats.json"), + "a", + ) as f: + f.write(json.dumps(outs) + "\n") + + def val_epoch(self, val_loader, phase): + batch_time = AverageMeter("Batch Time", self.device, ":.2f") + data_time = AverageMeter("Data Time", self.device, ":.2f") + mem = MemMeter("Mem (GB)", self.device, ":.2f") + + iters_per_epoch = len(val_loader) + + curr_phases = [phase] + curr_models = [self.model] + + loss_names = [] + for p in curr_phases: + for key in self.loss.keys(): + loss_names.append(f"Losses/{p}_{key}_loss") + + loss_mts = OrderedDict( + [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names] + ) + extra_loss_mts = {} + + for model in curr_models: + model.eval() + if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"): + unwrap_ddp_if_wrapped(model).on_validation_epoch_start() + + progress = ProgressMeter( + iters_per_epoch, + [batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()], + self._get_meters(curr_phases), + prefix="Val Epoch: [{}]".format(self.epoch), + ) + + end = time.time() + + for data_iter, batch in enumerate(val_loader): + + # measure data loading time + data_time.update(time.time() - end) + + batch = batch.to(self.device, non_blocking=True) + + # compute output + with torch.no_grad(): + with torch.cuda.amp.autocast( + enabled=(self.optim_conf.amp.enabled if self.optim_conf else False), + dtype=( + get_amp_type(self.optim_conf.amp.amp_dtype) + if self.optim_conf + else None + ), + ): + for phase, model in zip(curr_phases, curr_models): + loss_dict, batch_size, extra_losses = self._step( + batch, + model, + phase, + ) + + assert len(loss_dict) == 1 + loss_key, loss = loss_dict.popitem() + + loss_mts[loss_key].update(loss.item(), batch_size) + + for k, v in extra_losses.items(): + if k not in extra_loss_mts: + extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e") + extra_loss_mts[k].update(v.item(), batch_size) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + self.time_elapsed_meter.update( + time.time() - self.start_time + self.ckpt_time_elapsed + ) + + if torch.cuda.is_available(): + mem.update(reset_peak_usage=True) + + if data_iter % self.logging_conf.log_freq == 0: + progress.display(data_iter) + + if data_iter % self.logging_conf.log_scalar_frequency == 0: + # Log progress meters. + for progress_meter in progress.meters: + self.logger.log( + os.path.join("Step_Stats", phase, progress_meter.name), + progress_meter.val, + self.steps[Phase.VAL], + ) + + if data_iter % 10 == 0: + dist.barrier() + + self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch + self._log_timers(phase) + for model in curr_models: + if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"): + unwrap_ddp_if_wrapped(model).on_validation_epoch_end() + + out_dict = self._log_meters_and_save_best_ckpts(curr_phases) + + for k, v in loss_mts.items(): + out_dict[k] = v.avg + for k, v in extra_loss_mts.items(): + out_dict[k] = v.avg + + for phase in curr_phases: + out_dict.update(self._get_trainer_state(phase)) + self._reset_meters(curr_phases) + logging.info(f"Meters: {out_dict}") + return out_dict + + def _get_trainer_state(self, phase): + return { + "Trainer/where": self.where, + "Trainer/epoch": self.epoch, + f"Trainer/steps_{phase}": self.steps[phase], + } + + def train_epoch(self, train_loader): + + # Init stat meters + batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f") + data_time_meter = AverageMeter("Data Time", self.device, ":.2f") + mem_meter = MemMeter("Mem (GB)", self.device, ":.2f") + data_times = [] + phase = Phase.TRAIN + + iters_per_epoch = len(train_loader) + + loss_names = [] + for batch_key in self.loss.keys(): + loss_names.append(f"Losses/{phase}_{batch_key}_loss") + + loss_mts = OrderedDict( + [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names] + ) + extra_loss_mts = {} + + progress = ProgressMeter( + iters_per_epoch, + [ + batch_time_meter, + data_time_meter, + mem_meter, + self.time_elapsed_meter, + *loss_mts.values(), + ], + self._get_meters([phase]), + prefix="Train Epoch: [{}]".format(self.epoch), + ) + + # Model training loop + self.model.train() + end = time.time() + + for data_iter, batch in enumerate(train_loader): + # measure data loading time + data_time_meter.update(time.time() - end) + data_times.append(data_time_meter.val) + batch = batch.to( + self.device, non_blocking=True + ) # move tensors in a tensorclass + + try: + self._run_step(batch, phase, loss_mts, extra_loss_mts) + + # compute gradient and do optim step + exact_epoch = self.epoch + float(data_iter) / iters_per_epoch + self.where = float(exact_epoch) / self.max_epochs + assert self.where <= 1 + self.EPSILON + if self.where < 1.0: + self.optim.step_schedulers( + self.where, step=int(exact_epoch * iters_per_epoch) + ) + else: + logging.warning( + f"Skipping scheduler update since the training is at the end, i.e, {self.where} of [0,1]." + ) + + # Log schedulers + if data_iter % self.logging_conf.log_scalar_frequency == 0: + for j, param_group in enumerate(self.optim.optimizer.param_groups): + for option in self.optim.schedulers[j]: + optim_prefix = ( + "" + f"{j}_" + if len(self.optim.optimizer.param_groups) > 1 + else "" + ) + self.logger.log( + os.path.join("Optim", f"{optim_prefix}", option), + param_group[option], + self.steps[phase], + ) + + # Clipping gradients and detecting diverging gradients + if self.gradient_clipper is not None: + self.scaler.unscale_(self.optim.optimizer) + self.gradient_clipper(model=self.model) + + if self.gradient_logger is not None: + self.gradient_logger( + self.model, rank=self.distributed_rank, where=self.where + ) + + # Optimizer step: the scaler will make sure gradients are not + # applied if the gradients are infinite + self.scaler.step(self.optim.optimizer) + self.scaler.update() + + # measure elapsed time + batch_time_meter.update(time.time() - end) + end = time.time() + + self.time_elapsed_meter.update( + time.time() - self.start_time + self.ckpt_time_elapsed + ) + + mem_meter.update(reset_peak_usage=True) + if data_iter % self.logging_conf.log_freq == 0: + progress.display(data_iter) + + if data_iter % self.logging_conf.log_scalar_frequency == 0: + # Log progress meters. + for progress_meter in progress.meters: + self.logger.log( + os.path.join("Step_Stats", phase, progress_meter.name), + progress_meter.val, + self.steps[phase], + ) + + # Catching NaN/Inf errors in the loss + except FloatingPointError as e: + raise e + + self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch + self._log_timers(Phase.TRAIN) + self._log_sync_data_times(Phase.TRAIN, data_times) + + out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN]) + + for k, v in loss_mts.items(): + out_dict[k] = v.avg + for k, v in extra_loss_mts.items(): + out_dict[k] = v.avg + out_dict.update(self._get_trainer_state(phase)) + logging.info(f"Losses and meters: {out_dict}") + self._reset_meters([phase]) + return out_dict + + def _log_sync_data_times(self, phase, data_times): + data_times = all_reduce_max(torch.tensor(data_times)).tolist() + steps = range(self.steps[phase] - len(data_times), self.steps[phase]) + for step, data_time in zip(steps, data_times): + if step % self.logging_conf.log_scalar_frequency == 0: + self.logger.log( + os.path.join("Step_Stats", phase, "Data Time Synced"), + data_time, + step, + ) + + def _run_step( + self, + batch: BatchedVideoDatapoint, + phase: str, + loss_mts: Dict[str, AverageMeter], + extra_loss_mts: Dict[str, AverageMeter], + raise_on_error: bool = True, + ): + """ + Run the forward / backward + """ + + # it's important to set grads to None, especially with Adam since 0 + # grads will also update a model even if the step doesn't produce + # gradients + self.optim.zero_grad(set_to_none=True) + with torch.cuda.amp.autocast( + enabled=self.optim_conf.amp.enabled, + dtype=get_amp_type(self.optim_conf.amp.amp_dtype), + ): + loss_dict, batch_size, extra_losses = self._step( + batch, + self.model, + phase, + ) + + assert len(loss_dict) == 1 + loss_key, loss = loss_dict.popitem() + + if not math.isfinite(loss.item()): + error_msg = f"Loss is {loss.item()}, attempting to stop training" + logging.error(error_msg) + if raise_on_error: + raise FloatingPointError(error_msg) + else: + return + + self.scaler.scale(loss).backward() + loss_mts[loss_key].update(loss.item(), batch_size) + for extra_loss_key, extra_loss in extra_losses.items(): + if extra_loss_key not in extra_loss_mts: + extra_loss_mts[extra_loss_key] = AverageMeter( + extra_loss_key, self.device, ":.2e" + ) + extra_loss_mts[extra_loss_key].update(extra_loss.item(), batch_size) + + def _log_meters_and_save_best_ckpts(self, phases: List[str]): + logging.info("Synchronizing meters") + out_dict = {} + checkpoint_save_keys = [] + for key, meter in self._get_meters(phases).items(): + meter_output = meter.compute_synced() + is_better_check = getattr(meter, "is_better", None) + + for meter_subkey, meter_value in meter_output.items(): + out_dict[os.path.join("Meters_train", key, meter_subkey)] = meter_value + + if is_better_check is None: + continue + + tracked_meter_key = os.path.join(key, meter_subkey) + if tracked_meter_key not in self.best_meter_values or is_better_check( + meter_value, + self.best_meter_values[tracked_meter_key], + ): + self.best_meter_values[tracked_meter_key] = meter_value + + if ( + self.checkpoint_conf.save_best_meters is not None + and key in self.checkpoint_conf.save_best_meters + ): + checkpoint_save_keys.append(tracked_meter_key.replace("/", "_")) + + if len(checkpoint_save_keys) > 0: + self.save_checkpoint(self.epoch + 1, checkpoint_save_keys) + + return out_dict + + def _log_timers(self, phase): + time_remaining = 0 + epochs_remaining = self.max_epochs - self.epoch - 1 + val_epochs_remaining = sum( + n % self.val_epoch_freq == 0 for n in range(self.epoch, self.max_epochs) + ) + + # Adding the guaranteed val run at the end if val_epoch_freq doesn't coincide with + # the end epoch. + if (self.max_epochs - 1) % self.val_epoch_freq != 0: + val_epochs_remaining += 1 + + # Remove the current val run from estimate + if phase == Phase.VAL: + val_epochs_remaining -= 1 + + time_remaining += ( + epochs_remaining * self.est_epoch_time[Phase.TRAIN] + + val_epochs_remaining * self.est_epoch_time[Phase.VAL] + ) + + self.logger.log( + os.path.join("Step_Stats", phase, self.time_elapsed_meter.name), + self.time_elapsed_meter.val, + self.steps[phase], + ) + + logging.info(f"Estimated time remaining: {human_readable_time(time_remaining)}") + + def _reset_meters(self, phases: str) -> None: + for meter in self._get_meters(phases).values(): + meter.reset() + + def _check_val_key_match(self, val_keys, phase): + if val_keys is not None: + # Check if there are any duplicates + assert len(val_keys) == len( + set(val_keys) + ), f"Duplicate keys in val datasets, keys: {val_keys}" + + # Check that the keys match the meter keys + if self.meters_conf is not None and phase in self.meters_conf: + assert set(val_keys) == set(self.meters_conf[phase].keys()), ( + f"Keys in val datasets do not match the keys in meters." + f"\nMissing in meters: {set(val_keys) - set(self.meters_conf[phase].keys())}" + f"\nMissing in val datasets: {set(self.meters_conf[phase].keys()) - set(val_keys)}" + ) + + if self.loss_conf is not None: + loss_keys = set(self.loss_conf.keys()) - set(["all"]) + assert all([k in loss_keys for k in val_keys]), ( + f"Keys in val datasets do not match the keys in losses." + f"\nMissing in losses: {set(val_keys) - loss_keys}" + f"\nMissing in val datasets: {loss_keys - set(val_keys)}" + ) + + def _setup_components(self): + + # Get the keys for all the val datasets, if any + val_phase = Phase.VAL + val_keys = None + if self.data_conf.get(val_phase, None) is not None: + val_keys = collect_dict_keys(self.data_conf[val_phase]) + # Additional checks on the sanity of the config for val datasets + self._check_val_key_match(val_keys, phase=val_phase) + + logging.info("Setting up components: Model, loss, optim, meters etc.") + self.epoch = 0 + self.steps = {Phase.TRAIN: 0, Phase.VAL: 0} + + self.logger = Logger(self.logging_conf) + + self.model = instantiate(self.model_conf, _convert_="all") + print_model_summary(self.model) + + self.loss = None + if self.loss_conf: + self.loss = { + key: el # wrap_base_loss(el) + for (key, el) in instantiate(self.loss_conf, _convert_="all").items() + } + self.loss = nn.ModuleDict(self.loss) + + self.meters = {} + self.best_meter_values = {} + if self.meters_conf: + self.meters = instantiate(self.meters_conf, _convert_="all") + + self.scaler = torch.amp.GradScaler( + self.device, + enabled=self.optim_conf.amp.enabled if self.optim_conf else False, + ) + + self.gradient_clipper = ( + instantiate(self.optim_conf.gradient_clip) if self.optim_conf else None + ) + self.gradient_logger = ( + instantiate(self.optim_conf.gradient_logger) if self.optim_conf else None + ) + + logging.info("Finished setting up components: Model, loss, optim, meters etc.") + + def _construct_optimizers(self): + self.optim = construct_optimizer( + self.model, + self.optim_conf.optimizer, + self.optim_conf.options, + self.optim_conf.param_group_modifiers, + ) + + def _log_loss_detailed_and_return_core_loss(self, loss, loss_str, step): + core_loss = loss.pop(CORE_LOSS_KEY) + if step % self.logging_conf.log_scalar_frequency == 0: + for k in loss: + log_str = os.path.join(loss_str, k) + self.logger.log(log_str, loss[k], step) + return core_loss + + +def print_model_summary(model: torch.nn.Module, log_dir: str = ""): + """ + Prints the model and the number of parameters in the model. + # Multiple packages provide this info in a nice table format + # However, they need us to provide an `input` (as they also write down the output sizes) + # Our models are complex, and a single input is restrictive. + # https://github.com/sksq96/pytorch-summary + # https://github.com/nmhkahn/torchsummaryX + """ + if get_rank() != 0: + return + param_kwargs = {} + trainable_parameters = sum( + p.numel() for p in model.parameters(**param_kwargs) if p.requires_grad + ) + total_parameters = sum(p.numel() for p in model.parameters(**param_kwargs)) + non_trainable_parameters = total_parameters - trainable_parameters + logging.info("==" * 10) + logging.info(f"Summary for model {type(model)}") + logging.info(f"Model is {model}") + logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}") + logging.info( + f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}" + ) + logging.info( + f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}" + ) + logging.info("==" * 10) + + if log_dir: + output_fpath = os.path.join(log_dir, "model.txt") + with g_pathmgr.open(output_fpath, "w") as f: + print(model, file=f) + + +PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] + + +def get_human_readable_count(number: int) -> str: + """ + Abbreviates an integer number with K, M, B, T for thousands, millions, + billions and trillions, respectively. + Examples: + >>> get_human_readable_count(123) + '123 ' + >>> get_human_readable_count(1234) # (one thousand) + '1.2 K' + >>> get_human_readable_count(2e6) # (two million) + '2.0 M' + >>> get_human_readable_count(3e9) # (three billion) + '3.0 B' + >>> get_human_readable_count(4e14) # (four hundred trillion) + '400 T' + >>> get_human_readable_count(5e15) # (more than trillion) + '5,000 T' + Args: + number: a positive integer number + Return: + A string formatted according to the pattern described above. + """ + assert number >= 0 + labels = PARAMETER_NUM_UNITS + num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) + num_groups = int(np.ceil(num_digits / 3)) + num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions + shift = -3 * (num_groups - 1) + number = number * (10**shift) + index = num_groups - 1 + if index < 1 or number >= 100: + return f"{int(number):,d} {labels[index]}" + else: + return f"{number:,.1f} {labels[index]}" diff --git a/monst3r/third_party/sam2/training/utils/__init__.py b/monst3r/third_party/sam2/training/utils/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/monst3r/third_party/sam2/training/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/monst3r/third_party/sam2/training/utils/checkpoint_utils.py b/monst3r/third_party/sam2/training/utils/checkpoint_utils.py new file mode 100644 index 0000000..f76689f --- /dev/null +++ b/monst3r/third_party/sam2/training/utils/checkpoint_utils.py @@ -0,0 +1,361 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import fnmatch +import logging +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import numpy as np +import torch +import torch.nn as nn +from iopath.common.file_io import g_pathmgr +from torch.jit._script import RecursiveScriptModule + + +def unix_pattern_to_parameter_names( + constraints: List[str], all_parameter_names: Sequence[str] +) -> Union[None, Set[str]]: + """ + Go through the list of parameter names and select those that match + any of the provided constraints + """ + parameter_names = [] + for param_name in constraints: + matching_parameters = set(fnmatch.filter(all_parameter_names, param_name)) + assert ( + len(matching_parameters) > 0 + ), f"param_names {param_name} don't match any param in the given names." + parameter_names.append(matching_parameters) + return set.union(*parameter_names) + + +def filter_params_matching_unix_pattern( + patterns: List[str], state_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Remove from the state dictionary the parameters matching the provided unix patterns + + Args: + patterns: the list of unix patterns to exclude + state_dict: the dictionary to filter + + Returns: + A new state dictionary + """ + if len(patterns) == 0: + return {} + + all_keys = list(state_dict.keys()) + included_keys = unix_pattern_to_parameter_names(patterns, all_keys) + return {k: state_dict[k] for k in included_keys} + + +def exclude_params_matching_unix_pattern( + patterns: List[str], state_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Remove from the state dictionary the parameters matching the provided unix patterns + + Args: + patterns: the list of unix patterns to exclude + state_dict: the dictionary to filter + + Returns: + A new state dictionary + """ + if len(patterns) == 0: + return state_dict + + all_keys = list(state_dict.keys()) + excluded_keys = unix_pattern_to_parameter_names(patterns, all_keys) + return {k: v for k, v in state_dict.items() if k not in excluded_keys} + + +def _get_state_dict_summary(state_dict: Dict[str, torch.Tensor]): + keys = [] + trace = [] + for k, v in state_dict.items(): + keys.append(k) + trace.append(v.sum().item()) + trace = np.array(trace)[np.argsort(keys)] + return trace + + +def assert_skipped_parameters_are_frozen(model: nn.Module, patterns: List[str]): + """ + Verifies that all the parameters matching the provided patterns + are frozen - this acts as a safeguard when ignoring parameter + when saving checkpoints - if the parameters are in fact trainable + """ + if not patterns: + return + + frozen_state_dict = filter_params_matching_unix_pattern( + patterns=patterns, state_dict=model.state_dict() + ) + non_frozen_keys = { + n + for n, p in model.named_parameters() + if n in frozen_state_dict and p.requires_grad + } + if non_frozen_keys: + raise ValueError( + f"Parameters excluded with `skip_saving_parameters` should be frozen: {non_frozen_keys}" + ) + + +@contextlib.contextmanager +def with_check_parameter_frozen( + model: nn.Module, patterns: List[str], disabled: bool = True +): + """ + Context manager that inspects a model surrounding a piece of code + and verifies if the model has been updated by this piece of code + + The function will raise an exception if the model has been updated + on at least one of the parameter that matches one of the pattern + + Args: + model: the model that might have been updated + patterns: for the parameters we want to observe + allowed: + """ + if not patterns or disabled: + yield + return + + frozen_state_dict = filter_params_matching_unix_pattern( + patterns=patterns, state_dict=model.state_dict() + ) + summary_before = _get_state_dict_summary(frozen_state_dict) + + yield + + frozen_state_dict = filter_params_matching_unix_pattern( + patterns=patterns, state_dict=model.state_dict() + ) + summary_after = _get_state_dict_summary(frozen_state_dict) + + if not np.allclose(summary_before, summary_after, atol=1e-6): + raise ValueError( + f""" + The `model_weight_initializer` has initialized parameters frozen with `skip_saving_parameters`. + You can resolve this error by either initializing those parameters from within the model definition + or using the flag `trainer.checkpoint.initialize_after_preemption` to True. + """ + ) + + +class CkptExcludeKernel: + """ + Removes the keys from the given model state_dict that match the key_pattern. + + Args: + key_pattern: Patterns used to select the keys in the state_dict + that are eligible for this kernel. + """ + + def __init__(self, key_pattern: List[str]): + self.key_pattern = key_pattern + + def __call__(self, state_dict: Dict): + """ + Args: + state_dict: A dictionary representing the given checkpoint's state dict. + """ + if len(self.key_pattern) == 0: + return state_dict + exclude_keys = unix_pattern_to_parameter_names( + self.key_pattern, state_dict.keys() + ) + return {k: v for k, v in state_dict.items() if k not in exclude_keys} + + +def load_checkpoint( + path_list: List[str], + pick_recursive_keys: Optional[List[str]] = None, + map_location: str = "cpu", +) -> Any: + """ + Loads a checkpoint from the specified path. + + Args: + path_list: A list of paths which contain the checkpoint. Each element + is tried (in order) until a file that exists is found. That file is then + used to read the checkpoint. + pick_recursive_keys: Picks sub dicts from the loaded checkpoint if not None. + For pick_recursive_keys = ["a", "b"], will return checkpoint_dict["a"]["b"] + map_location (str): a function, torch.device, string or a dict specifying how to + remap storage locations + + Returns: Model with the matchin pre-trained weights loaded. + """ + path_exists = False + for path in path_list: + if g_pathmgr.exists(path): + path_exists = True + break + + if not path_exists: + raise ValueError(f"No path exists in {path_list}") + + with g_pathmgr.open(path, "rb") as f: + checkpoint = torch.load(f, map_location=map_location) + + logging.info(f"Loaded checkpoint from {path}") + if pick_recursive_keys is not None: + for key in pick_recursive_keys: + checkpoint = checkpoint[key] + return checkpoint + + +def get_state_dict(checkpoint, ckpt_state_dict_keys): + if isinstance(checkpoint, RecursiveScriptModule): + # This is a torchscript JIT model + return checkpoint.state_dict() + pre_train_dict = checkpoint + for i, key in enumerate(ckpt_state_dict_keys): + if (isinstance(pre_train_dict, Mapping) and key not in pre_train_dict) or ( + isinstance(pre_train_dict, Sequence) and key >= len(pre_train_dict) + ): + key_str = ( + '["' + '"]["'.join(list(map(ckpt_state_dict_keys[:i], str))) + '"]' + ) + raise KeyError( + f"'{key}' not found in checkpoint{key_str} " + f"with keys: {pre_train_dict.keys()}" + ) + pre_train_dict = pre_train_dict[key] + return pre_train_dict + + +def load_checkpoint_and_apply_kernels( + checkpoint_path: str, + checkpoint_kernels: List[Callable] = None, + ckpt_state_dict_keys: Tuple[str] = ("state_dict",), + map_location: str = "cpu", +) -> nn.Module: + """ + Performs checkpoint loading with a variety of pre-processing kernel applied in + sequence. + + Args: + checkpoint_path (str): Path to the checkpoint. + checkpoint_kernels List(Callable): A list of checkpoint processing kernels + to apply in the specified order. Supported kernels include `CkptIncludeKernel`, + `CkptExcludeKernel`, etc. These kernels are applied in the + given order. + ckpt_state_dict_keys (str): Keys containing the model state dict. + map_location (str): a function, torch.device, string or a dict specifying how to + remap storage locations + + Returns: Model with the matchin pre-trained weights loaded. + """ + assert g_pathmgr.exists(checkpoint_path), "Checkpoint '{}' not found".format( + checkpoint_path + ) + + # Load the checkpoint on CPU to avoid GPU mem spike. + with g_pathmgr.open(checkpoint_path, "rb") as f: + checkpoint = torch.load(f, map_location=map_location) + + pre_train_dict = get_state_dict(checkpoint, ckpt_state_dict_keys) + + # Not logging into info etc since it's a huge log + logging.debug( + "Loaded Checkpoint State Dict pre-kernel application: %s" + % str(", ".join(list(pre_train_dict.keys()))) + ) + # Apply kernels + if checkpoint_kernels is not None: + for f in checkpoint_kernels: + pre_train_dict = f(state_dict=pre_train_dict) + + logging.debug( + "Loaded Checkpoint State Dict Post-kernel application %s" + % str(", ".join(list(pre_train_dict.keys()))) + ) + + return pre_train_dict + + +def check_load_state_dict_errors( + missing_keys, + unexpected_keys, + strict: bool, + ignore_missing_keys: List[str] = None, + ignore_unexpected_keys: List[str] = None, +): + if ignore_missing_keys is not None and len(ignore_missing_keys) > 0: + ignored_keys = unix_pattern_to_parameter_names( + ignore_missing_keys, missing_keys + ) + missing_keys = [key for key in missing_keys if key not in ignored_keys] + + if ignore_unexpected_keys is not None and len(ignore_unexpected_keys) > 0: + ignored_unexpected_keys = unix_pattern_to_parameter_names( + ignore_unexpected_keys, unexpected_keys + ) + unexpected_keys = [ + key for key in unexpected_keys if key not in ignored_unexpected_keys + ] + + err = "State key mismatch." + if unexpected_keys: + err += f" Unexpected keys: {unexpected_keys}." + if missing_keys: + err += f" Missing keys: {missing_keys}." + + if unexpected_keys or missing_keys: + logging.warning(err) + if unexpected_keys or strict: + raise KeyError(err) + + +def load_state_dict_into_model( + state_dict: Dict, + model: nn.Module, + strict: bool = True, + ignore_missing_keys: List[str] = None, + ignore_unexpected_keys: List[str] = None, + checkpoint_kernels: List[Callable] = None, +): + """ + Loads a state dict into the given model. + + Args: + state_dict: A dictionary containing the model's + state dict, or a subset if strict is False + model: Model to load the checkpoint weights into + strict: raise if the state_dict has missing state keys + ignore_missing_keys: unix pattern of keys to ignore + """ + # Apply kernels + if checkpoint_kernels is not None: + for f in checkpoint_kernels: + state_dict = f(state_dict=state_dict) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + check_load_state_dict_errors( + missing_keys, + unexpected_keys, + strict=strict, + ignore_missing_keys=ignore_missing_keys, + ignore_unexpected_keys=ignore_unexpected_keys, + ) + return model diff --git a/monst3r/third_party/sam2/training/utils/data_utils.py b/monst3r/third_party/sam2/training/utils/data_utils.py new file mode 100644 index 0000000..fbd0115 --- /dev/null +++ b/monst3r/third_party/sam2/training/utils/data_utils.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from PIL import Image as PILImage +from tensordict import tensorclass + + +@tensorclass +class BatchedVideoMetaData: + """ + This class represents metadata about a batch of videos. + Attributes: + unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id) + frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch. + """ + + unique_objects_identifier: torch.LongTensor + frame_orig_size: torch.LongTensor + + +@tensorclass +class BatchedVideoDatapoint: + """ + This class represents a batch of videos with associated annotations and metadata. + Attributes: + img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch. + obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch. + masks: A [TxOxHxW] tensor containing binary masks for each object in the batch. + metadata: An instance of BatchedVideoMetaData containing metadata about the batch. + dict_key: A string key used to identify the batch. + """ + + img_batch: torch.FloatTensor + obj_to_frame_idx: torch.IntTensor + masks: torch.BoolTensor + metadata: BatchedVideoMetaData + + dict_key: str + + def pin_memory(self, device=None): + return self.apply(torch.Tensor.pin_memory, device=device) + + @property + def num_frames(self) -> int: + """ + Returns the number of frames per video. + """ + return self.batch_size[0] + + @property + def num_videos(self) -> int: + """ + Returns the number of videos in the batch. + """ + return self.img_batch.shape[1] + + @property + def flat_obj_to_img_idx(self) -> torch.IntTensor: + """ + Returns a flattened tensor containing the object to img index. + The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW] + """ + frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1) + flat_idx = video_idx * self.num_frames + frame_idx + return flat_idx + + @property + def flat_img_batch(self) -> torch.FloatTensor: + """ + Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW] + """ + + return self.img_batch.transpose(0, 1).flatten(0, 1) + + +@dataclass +class Object: + # Id of the object in the media + object_id: int + # Index of the frame in the media (0 if single image) + frame_index: int + segment: Union[torch.Tensor, dict] # RLE dict or binary mask + + +@dataclass +class Frame: + data: Union[torch.Tensor, PILImage.Image] + objects: List[Object] + + +@dataclass +class VideoDatapoint: + """Refers to an image/video and all its annotations""" + + frames: List[Frame] + video_id: int + size: Tuple[int, int] + + +def collate_fn( + batch: List[VideoDatapoint], + dict_key, +) -> BatchedVideoDatapoint: + """ + Args: + batch: A list of VideoDatapoint instances. + dict_key (str): A string key used to identify the batch. + """ + img_batch = [] + for video in batch: + img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)] + + img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4)) + T = img_batch.shape[0] + # Prepare data structures for sequential processing. Per-frame processing but batched across videos. + step_t_objects_identifier = [[] for _ in range(T)] + step_t_frame_orig_size = [[] for _ in range(T)] + + step_t_masks = [[] for _ in range(T)] + step_t_obj_to_frame_idx = [ + [] for _ in range(T) + ] # List to store frame indices for each time step + + for video_idx, video in enumerate(batch): + orig_video_id = video.video_id + orig_frame_size = video.size + for t, frame in enumerate(video.frames): + objects = frame.objects + for obj in objects: + orig_obj_id = obj.object_id + orig_frame_idx = obj.frame_index + step_t_obj_to_frame_idx[t].append( + torch.tensor([t, video_idx], dtype=torch.int) + ) + step_t_masks[t].append(obj.segment.to(torch.bool)) + step_t_objects_identifier[t].append( + torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx]) + ) + step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size)) + + obj_to_frame_idx = torch.stack( + [ + torch.stack(obj_to_frame_idx, dim=0) + for obj_to_frame_idx in step_t_obj_to_frame_idx + ], + dim=0, + ) + masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0) + objects_identifier = torch.stack( + [torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0 + ) + frame_orig_size = torch.stack( + [torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0 + ) + return BatchedVideoDatapoint( + img_batch=img_batch, + obj_to_frame_idx=obj_to_frame_idx, + masks=masks, + metadata=BatchedVideoMetaData( + unique_objects_identifier=objects_identifier, + frame_orig_size=frame_orig_size, + ), + dict_key=dict_key, + batch_size=[T], + ) diff --git a/monst3r/third_party/sam2/training/utils/distributed.py b/monst3r/third_party/sam2/training/utils/distributed.py new file mode 100644 index 0000000..f614b40 --- /dev/null +++ b/monst3r/third_party/sam2/training/utils/distributed.py @@ -0,0 +1,576 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import functools +import io +import logging +import os +import random +import tempfile +import time +from typing import Any, Callable, List, Tuple + +import torch +import torch.autograd as autograd +import torch.distributed as dist + + +# Default to GPU 0 +_cuda_device_index: int = 0 + +# Setting _cuda_device_index to -1 internally implies that we should use CPU +_CPU_DEVICE_INDEX = -1 +_PRIMARY_RANK = 0 + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + + if dist.get_backend() == "nccl": + # Increase timeout from 1800 sec to 43200 sec (12 hr) to avoid some processes + # being much slower than others causing a timeout (which can happen in relation + # or LVIS class mAP evaluation). + timeout = 43200 + return dist.new_group( + backend="gloo", + timeout=datetime.timedelta(seconds=timeout), + ) + + return dist.group.WORLD + + +def is_main_process(): + """Return true if the current process is the main one""" + return get_rank() == 0 + + +def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors), similar to + `all_gather` above, but using filesystem instead of collective ops. + + If gather_to_rank_0_only is True, only rank 0 will load the gathered object list + (and other ranks will have an empty list). + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + print("gathering via files") + cpu_group = _get_global_gloo_group() + + # if unspecified, we will save to the current python file dir + if filesys_save_dir is not None: + save_dir = filesys_save_dir + elif "EXP_DIR" in os.environ: + save_dir = os.environ["EXP_DIR"] + else: + # try the same directory where the code is stored + save_dir = filesys_save_dir or os.path.dirname(__file__) + save_dir = os.path.join(save_dir, "all_gather_via_filesys") + if is_main_process(): + os.makedirs(save_dir, exist_ok=True) + + # use a timestamp and salt to distinguish different all_gather + timestamp = int(time.time()) if is_main_process() else 0 + salt = random.randint(0, 2**31 - 1) if is_main_process() else 0 + # broadcast the timestamp and salt across ranks + # (all-reduce will do the broadcasting since only rank 0 is non-zero) + timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long) + dist.all_reduce(timestamp_and_salt, group=cpu_group) + timestamp, salt = timestamp_and_salt.tolist() + + # save the data to a file on the disk + rank_save = get_rank() + save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl" + save_data_path = os.path.join(save_dir, save_data_filename) + assert not os.path.exists(save_data_path), f"{save_data_path} already exists" + torch.save(data, save_data_path) + dist.barrier(group=cpu_group) + + # read the data from the files + data_list = [] + if rank_save == 0 or not gather_to_rank_0_only: + for rank_load in range(world_size): + load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl" + load_data_path = os.path.join(save_dir, load_data_filename) + assert os.path.exists(load_data_path), f"cannot read {save_data_path}" + data_list.append(torch.load(load_data_path)) + dist.barrier(group=cpu_group) + + # delete the saved file + os.remove(save_data_path) + return data_list + + +def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + + world_size = get_world_size() + if world_size == 1: + return [data] + + if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1": + return all_gather_via_filesys( + data, filesys_save_dir, gather_to_rank_0_only=True + ) + + if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys: + return all_gather_via_filesys(data, filesys_save_dir) + + cpu_group = None + if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu: + cpu_group = _get_global_gloo_group() + + buffer = io.BytesIO() + torch.save(data, buffer) + data_view = buffer.getbuffer() + device = "cuda" if cpu_group is None else "cpu" + tensor = torch.ByteTensor(data_view).to(device) + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long) + size_list = [ + torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size) + ] + if cpu_group is None: + dist.all_gather(size_list, local_size) + else: + print("gathering on cpu") + dist.all_gather(size_list, local_size, group=cpu_group) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + assert isinstance(local_size.item(), int) + local_size = int(local_size.item()) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) + if local_size != max_size: + padding = torch.empty( + size=(max_size - local_size,), dtype=torch.uint8, device=device + ) + tensor = torch.cat((tensor, padding), dim=0) + if cpu_group is None: + dist.all_gather(tensor_list, tensor) + else: + dist.all_gather(tensor_list, tensor, group=cpu_group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] + buffer = io.BytesIO(tensor.cpu().numpy()) + obj = torch.load(buffer) + data_list.append(obj) + + return data_list + + +def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]: + """ + For some backends, such as NCCL, communication only works if the + tensor is on the GPU. This helper function converts to the correct + device and returns the tensor + original device. + """ + orig_device = "cpu" if not tensor.is_cuda else "gpu" + if ( + torch.distributed.is_available() + and torch.distributed.get_backend() == torch.distributed.Backend.NCCL + and not tensor.is_cuda + ): + tensor = tensor.cuda() + return (tensor, orig_device) + + +def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor: + """ + For some backends, such as NCCL, communication only works if the + tensor is on the GPU. This converts the tensor back to original device. + """ + if tensor.is_cuda and orig_device == "cpu": + tensor = tensor.cpu() + return tensor + + +def is_distributed_training_run() -> bool: + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and (torch.distributed.get_world_size() > 1) + ) + + +def is_primary() -> bool: + """ + Returns True if this is rank 0 of a distributed training job OR if it is + a single trainer job. Otherwise False. + """ + return get_rank() == _PRIMARY_RANK + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing mean reduction + of tensor over all processes. + """ + return all_reduce_op( + tensor, + torch.distributed.ReduceOp.SUM, + lambda t: t / torch.distributed.get_world_size(), + ) + + +def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing sum + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM) + + +def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing min + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN) + + +def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing min + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX) + + +def all_reduce_op( + tensor: torch.Tensor, + op: torch.distributed.ReduceOp, + after_op_func: Callable[[torch.Tensor], torch.Tensor] = None, +) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + if is_distributed_training_run(): + tensor, orig_device = convert_to_distributed_tensor(tensor) + torch.distributed.all_reduce(tensor, op) + if after_op_func is not None: + tensor = after_op_func(tensor) + tensor = convert_to_normal_tensor(tensor, orig_device) + return tensor + + +def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]: + """ + Wrapper over torch.distributed.all_gather for performing + 'gather' of 'tensor' over all processes in both distributed / + non-distributed scenarios. + """ + if tensor.ndim == 0: + # 0 dim tensors cannot be gathered. so unsqueeze + tensor = tensor.unsqueeze(0) + + if is_distributed_training_run(): + tensor, orig_device = convert_to_distributed_tensor(tensor) + gathered_tensors = [ + torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(gathered_tensors, tensor) + gathered_tensors = [ + convert_to_normal_tensor(_tensor, orig_device) + for _tensor in gathered_tensors + ] + else: + gathered_tensors = [tensor] + + return gathered_tensors + + +def gather_from_all(tensor: torch.Tensor) -> torch.Tensor: + gathered_tensors = gather_tensors_from_all(tensor) + gathered_tensor = torch.cat(gathered_tensors, 0) + return gathered_tensor + + +def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """ + Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source + to all processes in both distributed / non-distributed scenarios. + """ + if is_distributed_training_run(): + tensor, orig_device = convert_to_distributed_tensor(tensor) + torch.distributed.broadcast(tensor, src) + tensor = convert_to_normal_tensor(tensor, orig_device) + return tensor + + +def barrier() -> None: + """ + Wrapper over torch.distributed.barrier, returns without waiting + if the distributed process group is not initialized instead of throwing error. + """ + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return + torch.distributed.barrier() + + +def get_world_size() -> int: + """ + Simple wrapper for correctly getting worldsize in both distributed + / non-distributed settings + """ + return ( + torch.distributed.get_world_size() + if torch.distributed.is_available() and torch.distributed.is_initialized() + else 1 + ) + + +def get_rank() -> int: + """ + Simple wrapper for correctly getting rank in both distributed + / non-distributed settings + """ + return ( + torch.distributed.get_rank() + if torch.distributed.is_available() and torch.distributed.is_initialized() + else 0 + ) + + +def get_primary_rank() -> int: + return _PRIMARY_RANK + + +def set_cuda_device_index(idx: int) -> None: + global _cuda_device_index + _cuda_device_index = idx + torch.cuda.set_device(_cuda_device_index) + + +def set_cpu_device() -> None: + global _cuda_device_index + _cuda_device_index = _CPU_DEVICE_INDEX + + +def get_cuda_device_index() -> int: + return _cuda_device_index + + +def init_distributed_data_parallel_model( + model: torch.nn.Module, + broadcast_buffers: bool = False, + find_unused_parameters: bool = True, + bucket_cap_mb: int = 25, +) -> torch.nn.parallel.DistributedDataParallel: + global _cuda_device_index + + if _cuda_device_index == _CPU_DEVICE_INDEX: + # CPU-only model, don't specify device + return torch.nn.parallel.DistributedDataParallel( + model, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + bucket_cap_mb=bucket_cap_mb, + ) + else: + # GPU model + return torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[_cuda_device_index], + output_device=_cuda_device_index, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + bucket_cap_mb=bucket_cap_mb, + ) + + +def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any: + """Broadcast an object from a source to all workers. + + Args: + obj: Object to broadcast, must be serializable + src: Source rank for broadcast (default is primary) + use_disk: If enabled, removes redundant CPU memory copies by writing to + disk + """ + # Either broadcast from primary to the fleet (default), + # or use the src setting as the original rank + if get_rank() == src: + # Emit data + buffer = io.BytesIO() + torch.save(obj, buffer) + data_view = buffer.getbuffer() + length_tensor = torch.LongTensor([len(data_view)]) + length_tensor = broadcast(length_tensor, src=src) + data_tensor = torch.ByteTensor(data_view) + data_tensor = broadcast(data_tensor, src=src) + else: + # Fetch from the source + length_tensor = torch.LongTensor([0]) + length_tensor = broadcast(length_tensor, src=src) + data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8) + data_tensor = broadcast(data_tensor, src=src) + if use_disk: + with tempfile.TemporaryFile("r+b") as f: + f.write(data_tensor.numpy()) + # remove reference to the data tensor and hope that Python garbage + # collects it + del data_tensor + f.seek(0) + obj = torch.load(f) + else: + buffer = io.BytesIO(data_tensor.numpy()) + obj = torch.load(buffer) + return obj + + +def all_gather_tensor(tensor: torch.Tensor, world_size=None): + if world_size is None: + world_size = get_world_size() + # make contiguous because NCCL won't gather the tensor otherwise + assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!" + tensor, orig_device = convert_to_distributed_tensor(tensor) + tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_all, tensor, async_op=False) # performance opt + tensor_all = [ + convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all + ] + return tensor_all + + +def all_gather_batch(tensors: List[torch.Tensor]): + """ + Performs all_gather operation on the provided tensors. + """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + tensor_list = [] + output_tensor = [] + for tensor in tensors: + tensor_all = all_gather_tensor(tensor, world_size) + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + return output_tensor + + +class GatherLayer(autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + dist.all_reduce(all_gradients) + return all_gradients[dist.get_rank()] + + +def all_gather_batch_with_grad(tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + tensor_list = [] + output_tensor = [] + + for tensor in tensors: + tensor_all = GatherLayer.apply(tensor) + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + return output_tensor + + +def unwrap_ddp_if_wrapped(model): + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + return model.module + return model + + +def create_new_process_group(group_size): + """ + Creates process groups of a gives `group_size` and returns + process group that current GPU participates in. + + `group_size` must divide the total number of GPUs (world_size). + + Modified from + https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60 + + Args: + group_size (int): number of GPU's to collaborate for sync bn + """ + + assert group_size > 0 + + world_size = torch.distributed.get_world_size() + if world_size <= 8: + if group_size > world_size: + logging.warning( + f"Requested group size [{group_size}] > world size [{world_size}]. " + "Assuming local debug run and capping it to world size." + ) + group_size = world_size + assert world_size >= group_size + assert world_size % group_size == 0 + + group = None + for group_num in range(world_size // group_size): + group_ids = range(group_num * group_size, (group_num + 1) * group_size) + cur_group = torch.distributed.new_group(ranks=group_ids) + if torch.distributed.get_rank() // group_size == group_num: + group = cur_group + # can not drop out and return here, every process must go through creation of all subgroups + + assert group is not None + return group + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True diff --git a/monst3r/third_party/sam2/training/utils/logger.py b/monst3r/third_party/sam2/training/utils/logger.py new file mode 100644 index 0000000..f4b4ef0 --- /dev/null +++ b/monst3r/third_party/sam2/training/utils/logger.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py +import atexit +import functools +import logging +import sys +import uuid +from typing import Any, Dict, Optional, Union + +from hydra.utils import instantiate + +from iopath.common.file_io import g_pathmgr +from numpy import ndarray +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter + +from training.utils.train_utils import get_machine_local_and_dist_rank, makedir + +Scalar = Union[Tensor, ndarray, int, float] + + +def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any): + makedir(log_dir) + summary_writer_method = SummaryWriter + return TensorBoardLogger( + path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs + ) + + +class TensorBoardWriterWrapper: + """ + A wrapper around a SummaryWriter object. + """ + + def __init__( + self, + path: str, + *args: Any, + filename_suffix: str = None, + summary_writer_method: Any = SummaryWriter, + **kwargs: Any, + ) -> None: + """Create a new TensorBoard logger. + On construction, the logger creates a new events file that logs + will be written to. If the environment variable `RANK` is defined, + logger will only log if RANK = 0. + + NOTE: If using the logger with distributed training: + - This logger can call collective operations + - Logs will be written on rank 0 only + - Logger must be constructed synchronously *after* initializing distributed process group. + + Args: + path (str): path to write logs to + *args, **kwargs: Extra arguments to pass to SummaryWriter + """ + self._writer: Optional[SummaryWriter] = None + _, self._rank = get_machine_local_and_dist_rank() + self._path: str = path + if self._rank == 0: + logging.info( + f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}" + ) + self._writer = summary_writer_method( + log_dir=path, + *args, + filename_suffix=filename_suffix or str(uuid.uuid4()), + **kwargs, + ) + else: + logging.debug( + f"Not logging meters on this host because env RANK: {self._rank} != 0" + ) + atexit.register(self.close) + + @property + def writer(self) -> Optional[SummaryWriter]: + return self._writer + + @property + def path(self) -> str: + return self._path + + def flush(self) -> None: + """Writes pending logs to disk.""" + + if not self._writer: + return + + self._writer.flush() + + def close(self) -> None: + """Close writer, flushing pending logs to disk. + Logs cannot be written after `close` is called. + """ + + if not self._writer: + return + + self._writer.close() + self._writer = None + + +class TensorBoardLogger(TensorBoardWriterWrapper): + """ + A simple logger for TensorBoard. + """ + + def log_dict(self, payload: Dict[str, Scalar], step: int) -> None: + """Add multiple scalar values to TensorBoard. + + Args: + payload (dict): dictionary of tag name and scalar value + step (int, Optional): step value to record + """ + if not self._writer: + return + for k, v in payload.items(): + self.log(k, v, step) + + def log(self, name: str, data: Scalar, step: int) -> None: + """Add scalar data to TensorBoard. + + Args: + name (string): tag name used to group scalars + data (float/int/Tensor): scalar data to log + step (int, optional): step value to record + """ + if not self._writer: + return + self._writer.add_scalar(name, data, global_step=step, new_style=True) + + def log_hparams( + self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar] + ) -> None: + """Add hyperparameter data to TensorBoard. + + Args: + hparams (dict): dictionary of hyperparameter names and corresponding values + meters (dict): dictionary of name of meter and corersponding values + """ + if not self._writer: + return + self._writer.add_hparams(hparams, meters) + + +class Logger: + """ + A logger class that can interface with multiple loggers. It now supports tensorboard only for simplicity, but you can extend it with your own logger. + """ + + def __init__(self, logging_conf): + # allow turning off TensorBoard with "should_log: false" in config + tb_config = logging_conf.tensorboard_writer + tb_should_log = tb_config and tb_config.pop("should_log", True) + self.tb_logger = instantiate(tb_config) if tb_should_log else None + + def log_dict(self, payload: Dict[str, Scalar], step: int) -> None: + if self.tb_logger: + self.tb_logger.log_dict(payload, step) + + def log(self, name: str, data: Scalar, step: int) -> None: + if self.tb_logger: + self.tb_logger.log(name, data, step) + + def log_hparams( + self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar] + ) -> None: + if self.tb_logger: + self.tb_logger.log_hparams(hparams, meters) + + +# cache the opened file object, so that different calls to `setup_logger` +# with the same file name can safely write to the same file. +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + # we tune the buffering value so that the logs are updated + # frequently. + log_buffer_kb = 10 * 1024 # 10KB + io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb) + atexit.register(io.close) + return io + + +def setup_logging( + name, + output_dir=None, + rank=0, + log_level_primary="INFO", + log_level_secondary="ERROR", +): + """ + Setup various logging streams: stdout and file handlers. + For file handlers, we only setup for the master gpu. + """ + # get the filename if we want to log to the file as well + log_filename = None + if output_dir: + makedir(output_dir) + if rank == 0: + log_filename = f"{output_dir}/log.txt" + + logger = logging.getLogger(name) + logger.setLevel(log_level_primary) + + # create formatter + FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s" + formatter = logging.Formatter(FORMAT) + + # Cleanup any existing handlers + for h in logger.handlers: + logger.removeHandler(h) + logger.root.handlers = [] + + # setup the console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + if rank == 0: + console_handler.setLevel(log_level_primary) + else: + console_handler.setLevel(log_level_secondary) + + # we log to file as well if user wants + if log_filename and rank == 0: + file_handler = logging.StreamHandler(_cached_log_stream(log_filename)) + file_handler.setLevel(log_level_primary) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + logging.root = logger + + +def shutdown_logging(): + """ + After training is done, we ensure to shut down all the logger streams. + """ + logging.info("Shutting down loggers...") + handlers = logging.root.handlers + for handler in handlers: + handler.close() diff --git a/monst3r/third_party/sam2/training/utils/train_utils.py b/monst3r/third_party/sam2/training/utils/train_utils.py new file mode 100644 index 0000000..91d5577 --- /dev/null +++ b/monst3r/third_party/sam2/training/utils/train_utils.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +import os +import random +import re +from datetime import timedelta +from typing import Optional + +import hydra + +import numpy as np +import omegaconf +import torch +import torch.distributed as dist +from iopath.common.file_io import g_pathmgr +from omegaconf import OmegaConf + + +def multiply_all(*args): + return np.prod(np.array(args)).item() + + +def collect_dict_keys(config): + """This function recursively iterates through a dataset configuration, and collect all the dict_key that are defined""" + val_keys = [] + # If the this config points to the collate function, then it has a key + if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]): + val_keys.append(config["dict_key"]) + else: + # Recursively proceed + for v in config.values(): + if isinstance(v, type(config)): + val_keys.extend(collect_dict_keys(v)) + elif isinstance(v, omegaconf.listconfig.ListConfig): + for item in v: + if isinstance(item, type(config)): + val_keys.extend(collect_dict_keys(item)) + return val_keys + + +class Phase: + TRAIN = "train" + VAL = "val" + + +def register_omegaconf_resolvers(): + OmegaConf.register_new_resolver("get_method", hydra.utils.get_method) + OmegaConf.register_new_resolver("get_class", hydra.utils.get_class) + OmegaConf.register_new_resolver("add", lambda x, y: x + y) + OmegaConf.register_new_resolver("times", multiply_all) + OmegaConf.register_new_resolver("divide", lambda x, y: x / y) + OmegaConf.register_new_resolver("pow", lambda x, y: x**y) + OmegaConf.register_new_resolver("subtract", lambda x, y: x - y) + OmegaConf.register_new_resolver("range", lambda x: list(range(x))) + OmegaConf.register_new_resolver("int", lambda x: int(x)) + OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x))) + OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x)) + + +def setup_distributed_backend(backend, timeout_mins): + """ + Initialize torch.distributed and set the CUDA device. + Expects environment variables to be set as per + https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization + along with the environ variable "LOCAL_RANK" which is used to set the CUDA device. + """ + # enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins + # of waiting + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" + logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins") + dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins)) + return dist.get_rank() + + +def get_machine_local_and_dist_rank(): + """ + Get the distributed and local rank of the current gpu. + """ + local_rank = int(os.environ.get("LOCAL_RANK", None)) + distributed_rank = int(os.environ.get("RANK", None)) + assert ( + local_rank is not None and distributed_rank is not None + ), "Please the set the RANK and LOCAL_RANK environment variables." + return local_rank, distributed_rank + + +def print_cfg(cfg): + """ + Supports printing both Hydra DictConfig and also the AttrDict config + """ + logging.info("Training with config:") + logging.info(OmegaConf.to_yaml(cfg)) + + +def set_seeds(seed_value, max_epochs, dist_rank): + """ + Set the python random, numpy and torch seed for each gpu. Also set the CUDA + seeds if the CUDA is available. This ensures deterministic nature of the training. + """ + # Since in the pytorch sampler, we increment the seed by 1 for every epoch. + seed_value = (seed_value + dist_rank) * max_epochs + logging.info(f"MACHINE SEED: {seed_value}") + random.seed(seed_value) + np.random.seed(seed_value) + torch.manual_seed(seed_value) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed_value) + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + logging.info(f"Error creating directory: {dir_path}") + return is_success + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_amp_type(amp_type: Optional[str] = None): + if amp_type is None: + return None + assert amp_type in ["bfloat16", "float16"], "Invalid Amp type." + if amp_type == "bfloat16": + return torch.bfloat16 + else: + return torch.float16 + + +def log_env_variables(): + env_keys = sorted(list(os.environ.keys())) + st = "" + for k in env_keys: + v = os.environ[k] + st += f"{k}={v}\n" + logging.info("Logging ENV_VARIABLES") + logging.info(st) + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, device, fmt=":f"): + self.name = name + self.fmt = fmt + self.device = device + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self._allow_updates = True + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +class MemMeter: + """Computes and stores the current, avg, and max of peak Mem usage per iteration""" + + def __init__(self, name, device, fmt=":f"): + self.name = name + self.fmt = fmt + self.device = device + self.reset() + + def reset(self): + self.val = 0 # Per iteration max usage + self.avg = 0 # Avg per iteration max usage + self.peak = 0 # Peak usage for lifetime of program + self.sum = 0 + self.count = 0 + self._allow_updates = True + + def update(self, n=1, reset_peak_usage=True): + self.val = torch.cuda.max_memory_allocated() // 1e9 + self.sum += self.val * n + self.count += n + self.avg = self.sum / self.count + self.peak = max(self.peak, self.val) + if reset_peak_usage: + torch.cuda.reset_peak_memory_stats() + + def __str__(self): + fmtstr = ( + "{name}: {val" + + self.fmt + + "} ({avg" + + self.fmt + + "}/{peak" + + self.fmt + + "})" + ) + return fmtstr.format(**self.__dict__) + + +def human_readable_time(time_seconds): + time = int(time_seconds) + minutes, seconds = divmod(time, 60) + hours, minutes = divmod(minutes, 60) + days, hours = divmod(hours, 24) + return f"{days:02}d {hours:02}h {minutes:02}m" + + +class DurationMeter: + def __init__(self, name, device, fmt=":f"): + self.name = name + self.device = device + self.fmt = fmt + self.val = 0 + + def reset(self): + self.val = 0 + + def update(self, val): + self.val = val + + def add(self, val): + self.val += val + + def __str__(self): + return f"{self.name}: {human_readable_time(self.val)}" + + +class ProgressMeter: + def __init__(self, num_batches, meters, real_meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.real_meters = real_meters + self.prefix = prefix + + def display(self, batch, enable_print=False): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + entries += [ + " | ".join( + [ + f"{os.path.join(name, subname)}: {val:.4f}" + for subname, val in meter.compute().items() + ] + ) + for name, meter in self.real_meters.items() + ] + logging.info(" | ".join(entries)) + if enable_print: + print(" | ".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def get_resume_checkpoint(checkpoint_save_dir): + if not g_pathmgr.isdir(checkpoint_save_dir): + return None + ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt") + if not g_pathmgr.isfile(ckpt_file): + return None + + return ckpt_file diff --git a/requirements.txt b/requirements.txt new file mode 100755 index 0000000..a40e34a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,43 @@ +torch==2.4.0 +torchvision>=0.19.0 +transformers==4.42.2 +accelerate>=0.34.2 +git+https://github.com/microsoft/DeepSpeed.git +git+https://github.com/facebookresearch/pytorchvideo.git + + +opencv-python +pillow==9.5.0 +imageio>=2.35.1 +imageio-ffmpeg==0.5.1 +ffmpeg-python +moviepy>=1.0.3 +decord==0.6.0 +kornia==0.7.3 +scikit-image + +git+https://github.com/openai/CLIP.git +openai>=1.45.0 +peft +xformers +bitsandbytes +sentencepiece>=0.2.0 + +numpy==1.26.0 +scipy +einops +rotary_embedding_torch +pytorch_lightning +easydict +pandas +matplotlib +gradio==3.40.0 +wandb +nvitop +omegaconf==2.3.0 +beartype==0.18.5 +fsspec==2024.5.0 +safetensors==0.4.3 +huggingface_hub==0.24.1 +av==13.1.0 +gdown \ No newline at end of file diff --git a/train/ds_config.json b/train/ds_config.json new file mode 100644 index 0000000..bec2010 --- /dev/null +++ b/train/ds_config.json @@ -0,0 +1,21 @@ +{ + "gradient_accumulation_steps": 1, + "gradient_clipping": 1.0, + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu" + }, + "offload_param": { + "device": "cpu" + } + }, + "steps_per_print": "inf", + "bf16": { + "enabled": true + }, + "fp16": { + "enabled": false + } +} \ No newline at end of file diff --git a/train/mask_process.py b/train/mask_process.py new file mode 100644 index 0000000..69a3867 --- /dev/null +++ b/train/mask_process.py @@ -0,0 +1,322 @@ +import numpy as np +import cv2 +import matplotlib.pyplot as plt + +from PIL import Image, ImageDraw +import math + +def generate_random_brush(h, w): + """生成随机笔画mask""" + mask = Image.new('L', (w, h), 0) + average_radius = math.sqrt(h*h+w*w) / 8 + max_tries = 5 + min_num_vertex, max_num_vertex = 1, 8 + mean_angle = 2*math.pi / 5 + angle_range = 2*math.pi / 15 + min_width, max_width = 128, 256 + # min_width = min(h, w) // 5 # 使用图像最大边长的1/5作为笔画宽度 + # max_width = min_width + + num_tries = np.random.choice([_ for _ in range(max_tries)], p=[0.05, 0.3, 0.3, 0.3, 0.05]) + for _ in range(num_tries): + num_vertex = np.random.randint(min_num_vertex, max_num_vertex) + angle_min = mean_angle - np.random.uniform(0, angle_range) + angle_max = mean_angle + np.random.uniform(0, angle_range) + angles = [] + vertex = [] + + for i in range(num_vertex): + if i % 2 == 0: + angles.append(2*math.pi - np.random.uniform(angle_min, angle_max)) + else: + angles.append(np.random.uniform(angle_min, angle_max)) + + vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h)))) + for i in range(num_vertex): + r = np.clip( + np.random.normal(loc=average_radius, scale=average_radius//2), + 0, 2*average_radius) + new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w) + new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h) + vertex.append((int(new_x), int(new_y))) + + draw = ImageDraw.Draw(mask) + width = int(np.random.uniform(min_width, max_width)) + draw.line(vertex, fill=1, width=width) + for v in vertex: + draw.ellipse((v[0] - width//2, + v[1] - width//2, + v[0] + width//2, + v[1] + width//2), + fill=1) + + mask = np.asarray(mask, np.uint8) + if np.random.random() > 0.5: + mask = np.flip(mask, 0) + if np.random.random() > 0.5: + mask = np.flip(mask, 1) + return mask + +def transform_video_masks( + video_masks, + p_brush=0.25, + p_rect=0.25, + p_ellipse=0.2, + p_circle=0.2, + p_random_brush=0.1, + margin_ratio=0.1, + shape_scale_min=1.1, + shape_scale_max=1.5, + brush_iterations=1 +): + """ + 转换视频序列中的masks,整个视频序列使���相同的变换参数 + + 参数: + video_masks: numpy array, shape (F, H, W, C) + """ + F, H, W, C = video_masks.shape + transformed_video_masks = np.zeros_like(video_masks) + + # 为整个视频序列选择一种mask变化模式 + choice = np.random.choice( + ['brush', 'rect', 'ellipse', 'circle', 'random_brush'], + p=[p_brush, p_rect, p_ellipse, p_circle, p_random_brush] + ) + + # 预先确定所有随机参数 + if choice == 'brush': + # 选择一种形态学操作类型 + morph_type = np.random.choice([ + 'dilate_erode', # 闭运算 + 'erode_dilate', # 开运算 + 'dilate_only', # 仅膨胀 + 'combined' # 组合操作 + ]) + # 决定是否使用高斯模糊 + use_blur = np.random.random() < 0.1 + # print(f"choice: {choice}, morph_type: {morph_type}, use_blur: {use_blur}") + + elif choice == 'random_brush': + # 为整个序列生成一个随机笔画mask + first_frame_brush = generate_random_brush(H, W) + # print(f"choice: {choice}") + + elif choice == 'rect': + # 预定义矩形参数 + rect_angle = np.random.uniform(0, 360) + width_scale = np.random.uniform(shape_scale_min, shape_scale_max) + height_scale = np.random.uniform(shape_scale_min, shape_scale_max) + + elif choice == 'ellipse': + # 预定义椭圆参数 + width_scale = np.random.uniform(shape_scale_min/2, shape_scale_max/2) + height_scale = np.random.uniform(shape_scale_min/2, shape_scale_max/2) + angle = np.random.uniform(0, 360) + + else: # circle + # 预定义圆形参数 + radius_scale = np.random.uniform(shape_scale_min/2, shape_scale_max/2) + + # 为rect、ellipse和circle模式预先生成mask + if choice in ['rect', 'ellipse', 'circle']: + # 使用第一帧来获取边界框 + first_frame = video_masks[0] + H, W, C = first_frame.shape + y_indices, x_indices = np.where(first_frame[:,:,0] > 0) + if len(y_indices) == 0 or len(x_indices) == 0: + return video_masks + + x_min, x_max = np.min(x_indices), np.max(x_indices) + y_min, y_max = np.min(y_indices), np.max(y_indices) + + # 添加边界扰动 + margin = int(min(H, W) * margin_ratio) + x_min = max(0, x_min - np.random.randint(0, margin)) + x_max = min(W, x_max + np.random.randint(0, margin)) + y_min = max(0, y_min - np.random.randint(0, margin)) + y_max = min(H, y_max + np.random.randint(0, margin)) + + center_x = (x_min + x_max) // 2 + center_y = (y_min + y_max) // 2 + width = x_max - x_min + height = y_max - y_min + + # 预先生成形状mask + first_frame_shape = np.zeros((H, W), dtype=np.uint8) + + if choice == 'rect': + rect = ( + (float(center_x), float(center_y)), + (float(width * width_scale), float(height * height_scale)), + float(rect_angle) + ) + box = cv2.boxPoints(rect).astype(np.int32) + cv2.fillPoly(first_frame_shape, [box], 1) + + elif choice == 'ellipse': + axes_length = (int(width * width_scale), int(height * height_scale)) + cv2.ellipse(first_frame_shape, + (center_x, center_y), + axes_length, + angle, + 0, 360, 1, -1) + + else: # circle + radius = int(max(width, height) * radius_scale) + cv2.circle(first_frame_shape, + (center_x, center_y), + radius, + 1, -1) + + def transform_frame(mask): + H, W, C = mask.shape + transformed_mask = np.zeros((H, W, C), dtype=np.uint8) + + if choice == 'random_brush': + transformed_mask[:,:,0] = first_frame_brush + elif choice in ['rect', 'ellipse', 'circle']: + transformed_mask[:,:,0] = first_frame_shape + elif choice == 'brush': + kernel = np.ones((32, 32), np.uint8) + + if morph_type == 'dilate_erode': + dilated = cv2.dilate(mask[:,:,0].astype(np.uint8), kernel, iterations=brush_iterations) + transformed_mask[:,:,0] = cv2.erode(dilated, kernel, iterations=brush_iterations) + + elif morph_type == 'erode_dilate': + eroded = cv2.erode(mask[:,:,0].astype(np.uint8), kernel, iterations=brush_iterations) + transformed_mask[:,:,0] = cv2.dilate(eroded, kernel, iterations=brush_iterations) + + elif morph_type == 'dilate_only': + transformed_mask[:,:,0] = cv2.dilate(mask[:,:,0].astype(np.uint8), kernel, iterations=brush_iterations) + + else: # combined + eroded = cv2.erode(mask[:,:,0].astype(np.uint8), kernel, iterations=brush_iterations) + opened = cv2.dilate(eroded, kernel, iterations=brush_iterations) + dilated = cv2.dilate(opened, kernel, iterations=brush_iterations) + transformed_mask[:,:,0] = cv2.erode(dilated, kernel, iterations=brush_iterations) + + if use_blur: + blur_size = max(3, 8 // 4) + if blur_size % 2 == 0: + blur_size += 1 + transformed_mask[:,:,0] = cv2.GaussianBlur(transformed_mask[:,:,0], (blur_size, blur_size), 0) + transformed_mask = (transformed_mask > 0.5).astype(np.uint8) + + # 复制到其他通道 + transformed_mask[:,:,1:] = transformed_mask[:,:,0:1] + return transformed_mask + + # 处理每一帧 + for f in range(F): + transformed_video_masks[f] = transform_frame(video_masks[f]) + + return transformed_video_masks + +def test_transform_video_masks(): + """测试mask变换函数""" + # 创建测试用的视频mask序列 + F, H, W, C = 10, 480, 720, 3 + video_masks = np.zeros((F, H, W, C), dtype=np.uint8) + + # 创建一些模拟的分割mask + for f in range(F): + # 创建基础mask + frame_mask = np.zeros((H, W), dtype=np.uint8) + + # 定义缩放因子,使mask更大 + scale = min(H, W) // 2 # 根据图像尺寸自适应缩放 + + # 添加一个不规则形状(模拟人物轮廓) + points = np.array([ + [scale * (0.5 + np.random.randint(-5, 5)/100), scale * (0.3 + np.random.randint(-2, 2)/100)], # 头部 + [scale * (0.3 + np.random.randint(-3, 3)/100), scale * (0.5 + np.random.randint(-2, 2)/100)], # 左肩 + [scale * (0.7 + np.random.randint(-3, 3)/100), scale * (0.5 + np.random.randint(-2, 2)/100)], # 右肩 + [scale * (0.2 + np.random.randint(-3, 3)/100), scale * (0.9 + np.random.randint(-2, 2)/100)], # 左脚 + [scale * (0.8 + np.random.randint(-3, 3)/100), scale * (0.9 + np.random.randint(-2, 2)/100)], # 右脚 + ], dtype=np.int32) + + # 将轮廓移动到图像中心 + center_offset_x = W // 2 - scale // 2 + center_offset_y = H // 2 - scale // 2 + points[:, 0] += center_offset_x + points[:, 1] += center_offset_y + + # 绘制人物轮廓 + cv2.fillPoly(frame_mask, [points], 1) + + # 添加更大的随机物体 + for _ in range(np.random.randint(2, 4)): + cx = np.random.randint(10, W-10) + cy = np.random.randint(10, H-10) + radius = np.random.randint(scale//20, scale//10) # 增大随机物体的半径 + cv2.circle(frame_mask, (cx, cy), radius, 1, -1) + + # 应用一些随机变形和模糊,使mask更自然 + if np.random.random() < 0.5: + kernel = np.ones((3, 3), np.uint8) + frame_mask = cv2.morphologyEx(frame_mask, cv2.MORPH_CLOSE, kernel) + + # 复制到所有通道 + video_masks[f, :, :, :] = frame_mask[:, :, np.newaxis] + + # 测试不同的变换参数组合 + test_cases = [ + # 测试全部使用brush变换 + {'p_brush': 1.0, 'p_rect': 0, 'p_ellipse': 0, 'p_circle': 0, 'p_random_brush': 0, 'brush_iterations': 2}, + # 测试全部使用矩形变换 + {'p_brush': 0, 'p_rect': 1.0, 'p_ellipse': 0, 'p_circle': 0, 'p_random_brush': 0}, + # 测试全部使用椭圆变换 + {'p_brush': 0, 'p_rect': 0, 'p_ellipse': 1.0, 'p_circle': 0, 'p_random_brush': 0}, + # 测试全部使用圆形变换 + {'p_brush': 0, 'p_rect': 0, 'p_ellipse': 0, 'p_circle': 1.0, 'p_random_brush': 0}, + # 测试全部使用随机笔画变换 + {'p_brush': 0, 'p_rect': 0, 'p_ellipse': 0, 'p_circle': 0, 'p_random_brush': 1.0}, + # 测试均匀分布(包含随机笔画) + {'p_brush': 0.2, 'p_rect': 0.2, 'p_ellipse': 0.2, 'p_circle': 0.2, 'p_random_brush': 0.2}, + ] + + for i, params in enumerate(test_cases): + print(f"\n测试用例 {i + 1}:") + print(f"参数: {params}") + + transformed_masks = transform_video_masks( + video_masks, + **params + ) + + # 基本断言测试 + assert transformed_masks.shape == video_masks.shape, "转换后的mask形状应该保持不变" + assert transformed_masks.dtype == video_masks.dtype, "转换后的mask数据类型应该保持不变" + assert np.any(transformed_masks != video_masks), "转换后的mask应该与原始mask不同" + + # 保存结果 + save_test_results(video_masks, transformed_masks, i) + +def save_test_results(original_masks, transformed_masks, test_case_index): + """保存测试结果为图片""" + import matplotlib.pyplot as plt + + # 创建一个图像网格来显示原始mask和转换后的mask + fig, axes = plt.subplots(2, 5, figsize=(15, 6)) + + # 显示5帧原始mask + for i in range(5): + axes[0, i].imshow(original_masks[i, :, :, 0], cmap='gray') + axes[0, i].set_title(f'Original Frame {i}') + axes[0, i].axis('off') + + # 显示5帧转换后的mask + for i in range(5): + axes[1, i].imshow(transformed_masks[i, :, :, 0], cmap='gray') + axes[1, i].set_title(f'Transformed Frame {i}') + axes[1, i].axis('off') + + plt.tight_layout() + plt.savefig(f'mask_transform_test_{test_case_index}.png') + plt.close() + +if __name__ == "__main__": + test_transform_video_masks() + print("所有测试完成!") \ No newline at end of file diff --git a/train/train.py b/train/train.py new file mode 100755 index 0000000..e42ee3d --- /dev/null +++ b/train/train.py @@ -0,0 +1,2024 @@ +# Standard library imports +import os +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) +import argparse +import logging +import math +import shutil +import json +import random +import time +import gc +from pathlib import Path +from typing import List, Optional, Tuple, Union +from multiprocessing import Pool, cpu_count + +# Third-party imports +import numpy as np +import pandas as pd +import cv2 +from PIL import Image +from tqdm.auto import tqdm + +# PyTorch imports +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset +import torchvision.transforms as TT +from torchvision import transforms +from torchvision.transforms.functional import center_crop, resize +from torchvision.transforms import InterpolationMode + +# Hugging Face imports +import transformers +from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer +from huggingface_hub import create_repo, upload_folder +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict + +# Diffusers imports +import diffusers +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXPipeline, + CogVideoXTransformer3DModel, + CogvideoXBranchModel, + CogVideoXI2VDualInpaintPipeline +) +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.optimization import get_scheduler +from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid +from diffusers.training_utils import ( + cast_training_params, + clear_objs_and_retain_memory, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + export_to_video, + is_wandb_available, + load_image, + load_video +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module +from diffusers.utils.import_utils import is_xformers_available + +from mask_process import transform_video_masks + +if is_wandb_available(): + import wandb + +# Configure logging +logger = get_logger(__name__) + +def get_gradient_norm(parameters): + norm = 0 + for param in parameters: + if param.grad is None: + continue + local_norm = param.grad.detach().data.norm(2) + norm += local_norm.item() ** 2 + norm = norm**0.5 + return norm + +def concatenate_images_horizontally(images1, images2, images3, output_type="np"): + ''' + Concatenate three lists of images horizontally. + Args: + images1: List[Image.Image] or List[np.ndarray] + images2: List[Image.Image] or List[np.ndarray] + images3: List[Image.Image] or List[np.ndarray] + Returns: + List[Image.Image] or List[np.ndarray] + ''' + concatenated_images = [] + for img1, img2, img3 in zip(images1, images2, images3): + # Convert images to numpy arrays + arr1 = np.array(img1) + arr2 = np.array(img2) + arr3 = np.array(img3) + + # Concatenate arrays horizontally + concatenated_img = np.concatenate((arr1, arr2, arr3), axis=1) + + # Convert back to PIL Image + if output_type == "pil": + concatenated_img = Image.fromarray(concatenated_img) + elif output_type == "np": + pass + else: + raise NotImplementedError + concatenated_images.append(concatenated_img) + return concatenated_images + +def get_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for VideoPainter.") + + # Model information + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + # Dataset information + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--condition_data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--val_condition_data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--video_data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + + # Validation + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_prompt_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`." + ), + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=6, + help="The guidance scale to use while sampling validation videos.", + ) + parser.add_argument( + "--use_dynamic_cfg", + action="store_true", + default=False, + help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.", + ) + + # Training information + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--rank", + type=int, + default=128, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=float, + default=128, + help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="cogvideox-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="All input videos are resized to this height.", + ) + parser.add_argument( + "--width", + type=int, + default=720, + help="All input videos are resized to this width.", + ) + parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") + parser.add_argument( + "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames." + ) + parser.add_argument( + "--skip_frames_start", + type=int, + default=0, + help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", + ) + parser.add_argument( + "--skip_frames_end", + type=int, + default=0, + help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip videos horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--validating_steps", + type=int, + default=50, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) + + # Optimizer + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw", "prodigy"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", + ) + parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.") + parser.add_argument( + "--prodigy_safeguard_warmup", + action="store_true", + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", + ) + + # Other information + parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") + parser.add_argument("--runs_name", type=str, default=None, help="Runs name") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="Directory where logs are stored.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default=None, + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--cogvideox_branch_name_or_path", + type=str, + default=None, + help="The path to a pre-trained CogVideoX branch model to use for training.", + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--meta_file_path", + type=str, + default=None, + help="The path to meta data.", + ) + parser.add_argument( + "--val_meta_file_path", + type=str, + default=None, + help="The path to meta data.", + ) + parser.add_argument( + "--corrupt_file_path", + type=str, + default=None, + help="The path to corrupt data.", + ) + parser.add_argument( + "--random_mask", + action="store_true", + help=( + "Training CogVideoX with random mask" + ), + ) + parser.add_argument( + "--resolution", + type=int, + nargs=2, + default=[480, 720], + help=( + "The resolution for input videos, all the videos in the train/validation dataset will be resized to this" + " resolution. Provide two integers: height and width." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--max_text_seq_length", + type=int, + default=226, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + # branch + parser.add_argument( + "--branch_layer_num", + type=int, + default=4, + help="Number of layers in the branch.", + ) + parser.add_argument( + "--add_first", + action='store_true', + help="Enable add_first feature. Default is False.", + ) + parser.add_argument( + "--mask_add", + action='store_true', + help="Enable mask_add feature. Default is False.", + ) + parser.add_argument( + "--mask_background", + action='store_true', + help="Enable mask_background feature. Default is False.", + ) + parser.add_argument( + "--wo_text", + action='store_true', + help="Enable wo_text feature. Default is False.", + ) + parser.add_argument( + "--inpainting_loss_weight", + type=float, + default=1.0, + help="The weight of inpainting loss.", + ) + parser.add_argument( + "--mix_train_ratio", + type=float, + default=0.0, + help="The ratio of mix training of images and videos.", + ) + parser.add_argument( + "--first_frame_gt", + action='store_true', + help="Enable first_frame_gt feature. Default is False.", + ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Whether or not to use the pinned memory setting in pytorch dataloader.", + ) + parser.add_argument( + "--noised_image_dropout", + type=float, + default=0.05, + help="Image condition dropout probability when finetuning image-to-video.", + ) + parser.add_argument( + "--video_reshape_mode", + type=str, + default="resize", + help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'resize']", + ) + parser.add_argument( + "--mask_transform_prob", + type=float, + default=0.7, + help="Probability of transforming each SAM mask", + ) + parser.add_argument( + "--p_brush", + type=float, + default=0.3, + help="Probability of using brush mask", + ) + parser.add_argument( + "--p_rect", + type=float, + default=0.3, + help="Probability of using rectangle mask", + ) + parser.add_argument( + "--p_ellipse", + type=float, + default=0.2, + help="Probability of using ellipse mask", + ) + parser.add_argument( + "--p_circle", + type=float, + default=0.2, + help="Probability of using circle mask", + ) + parser.add_argument( + "--p_random_brush", + type=float, + default=0.0, + help="Probability of using random brush mask", + ) + parser.add_argument( + "--margin_ratio", + type=float, + default=0.1, + help="Margin ratio for mask boundary perturbation", + ) + parser.add_argument( + "--shape_scale_min", + type=float, + default=1.1, + help="Minimum scale ratio for shape transformation", + ) + parser.add_argument( + "--shape_scale_max", + type=float, + default=1.5, + help="Maximum scale ratio for shape transformation", + ) + return parser.parse_args() + +class VideoInpaintingDataset(Dataset): + def __init__( + self, + meta_file_path: Optional[str] = None, + condition_data_root: Optional[str] = None, + video_data_root: Optional[str] = None, + dataset_name: Optional[str] = None, + dataset_config_name: Optional[str] = None, + height: int = 480, + width: int = 720, + fps: int = 8, + max_num_frames: int = 49, + skip_frames_start: int = 0, + skip_frames_end: int = 0, + cache_dir: Optional[str] = None, + id_token: Optional[str] = None, + is_train: bool = True, + load_mode: str = "online", + ) -> None: + super().__init__() + self.meta_file_path = Path(meta_file_path) + self.meta_file_path_str = meta_file_path + self.condition_data_root = Path(condition_data_root) if condition_data_root is not None else None + self.video_data_root = Path(video_data_root) if video_data_root is not None else None + self.dataset_name = dataset_name + self.dataset_config_name = dataset_config_name + self.height = height + self.width = width + self.fps = fps + self.max_num_frames = max_num_frames + self.skip_frames_start = skip_frames_start + self.skip_frames_end = skip_frames_end + self.cache_dir = cache_dir + self.id_token = id_token or "" + self.is_train = is_train + self.load_mode = load_mode + if dataset_name is not None: + self._load_dataset_from_hub() + else: + self._load_dataset_from_local_path() + + def __len__(self): + return len(self.videos) + + def __getitem__(self, index): + if index in self.bad_index: + print(f"Bad index: {index}") + return self.__getitem__(random.randint(0, len(self.videos) - 1)) + if self.load_mode == "online": + + prompt = self.prompts[index] + video_frames_paths = self.videos[index] + condition_video_frames_paths = self.condition_videos[index] + scene_idx = self.scene_idxs[index] + + try: + + video_frames = [cv2.imread(str(frame_path)) for frame_path in video_frames_paths] + condition_video_frames = [cv2.imread(str(frame_path)) for frame_path in condition_video_frames_paths] + + # transform to BGR + # video_frames = [cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) for frame in video_frames] + # condition_video_frames = [cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) for frame in condition_video_frames] + + except Exception as e: + print(f"Error loading video frames: {e}") + self.bad_index.append(index) + return self.__getitem__(random.randint(0, len(self.videos) - 1)) + + + video = np.array(video_frames) + condition_video = np.array(condition_video_frames) + + # Align to the 8 fps + # video = video[::int(fps//self.fps)] + # condition_video = condition_video[::int(fps//self.fps)] + + return { + "prompt": prompt, + "video": video, + "condition_video": condition_video, + "scene_idx": scene_idx, + } + else: + raise NotImplementedError + + def _load_dataset_from_hub(self): + raise NotImplementedError + + def _load_dataset_from_local_path(self): + ''' + read the meta file and corrupt file, and return the metas + ''' + if not self.meta_file_path.exists(): + raise ValueError(f"Meta file does not exist: {self.meta_file_path}") + if not self.video_data_root.exists(): + raise ValueError(f"Instance videos root folder does not exist: {self.video_data_root}") + if not self.condition_data_root.exists(): + raise ValueError(f"Condition videos root folder does not exist: {self.condition_data_root}") + + if not self.is_train: + with open('/lpai/volumes/ad-vla-vol-ga/chenantong/codes/balanced_1k_selected_idx.txt', 'r') as f: + self.selected_idx = [int(x) for x in f.read().split(',')] + + try: + metadata = json.load(open(self.meta_file_path, "r")) + except: + metadata = [] + with open(self.meta_file_path, 'r') as f: + for line in f: + metadata.append(json.loads(line)) + + + self.prompts = ["Moving view of a city scene." for _ in range(len(metadata))] + # self.videos = [[self.video_data_root / gt_frame_path for gt_frame_path in sample["gt_frames"]] for sample in metadata] + # self.condition_videos = [[self.condition_data_root / render_frame_path for render_frame_path in sample["render_frames"]] for sample in metadata] + # self.scene_idxs = [sample["scene_idx"] for sample in metadata] + self.bad_index = [] + + self.videos = [] + self.condition_videos = [] + self.scene_idxs = [] + for sample in metadata: + if 'waymo' in sample['gt_frames'][0]: # from waymo + self.videos.append([gt_frame_path for gt_frame_path in sample["gt_frames"]]) + self.condition_videos.append([render_frame_path for render_frame_path in sample["render_frames"]]) + self.scene_idxs.append(sample["scene_idx"]) + print(f"waymo: {sample['render_frames'][0]}") + else: + self.videos.append([self.video_data_root / gt_frame_path for gt_frame_path in sample["gt_frames"]]) + self.condition_videos.append([self.condition_data_root / render_frame_path for render_frame_path in sample["render_frames"]]) + self.scene_idxs.append(sample["scene_idx"]) + + if not self.is_train: + self.prompts = [p for i, p in enumerate(self.prompts) if i in self.selected_idx] + self.videos = [v for i, v in enumerate(self.videos) if i in self.selected_idx] + self.condition_videos = [c for i, c in enumerate(self.condition_videos) if i in self.selected_idx] + self.scene_idxs = [s for i, s in enumerate(self.scene_idxs) if i in self.selected_idx] + + +class MyWebDataset(): + def __init__(self,resolution,tokenizer,random_mask, max_num_frames, max_sequence_length, proportion_empty_prompts, is_train=True, mask_background=False, mix_train_ratio=0.0, first_frame_gt=False, video_reshape_mode="resize", random_flip: Optional[float] = None, mask_transform_prob=0.7, p_brush=0.3, p_rect=0.3, p_ellipse=0.2, p_circle=0.2, p_random_brush=0.0, margin_ratio=0.1, shape_scale_min=1.1, shape_scale_max=1.5): + self.resolution = resolution + self.tokenizer = tokenizer + self.random_mask = random_mask + self.max_num_frames = max_num_frames + self.val_max_num_frames = max_num_frames + self.max_sequence_length = max_sequence_length + self.proportion_empty_prompts = proportion_empty_prompts + self.mask_background = mask_background + self.is_train = is_train + self.mix_train_ratio = mix_train_ratio + self.first_frame_gt = first_frame_gt + self.video_reshape_mode = video_reshape_mode + + self.resolutions = (max_num_frames, resolution[0], resolution[1]) + + + def tokenize_captions(self, caption, is_train=True): + if random.random() < self.proportion_empty_prompts: + caption="" + elif isinstance(caption, str): + caption=caption + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + caption=random.choice(caption) if is_train else caption[0] + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = self.tokenizer( + caption, max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + def _find_nearest_resolution(self, height, width): + # nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + nearest_res = [self.resolutions[0], self.resolutions[1], self.resolutions[2]] + return nearest_res[1], nearest_res[2] + + def _resize_for_rectangle_crop(self, arr, image_size): + reshape_mode = self.video_reshape_mode + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] + delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + + def __call__(self, examples): + pixel_values=[] + conditioning_pixel_values=[] + input_ids=[] + + if self.is_train: + + for example in examples: + caption=example["prompt"] + video=example["video"] # frame, height, width, c + condition_video=example["condition_video"] # frame, height, width, c + frame, height, width, c = video.shape + + if frame > self.max_num_frames: + begin_idx = random.randint(0, frame - self.max_num_frames) + end_idx = begin_idx + self.max_num_frames + video = video[begin_idx:end_idx] + condition_video = condition_video[begin_idx:end_idx] + frame = end_idx - begin_idx + elif frame <= self.max_num_frames: + remainder = (3 + (frame % 4)) % 4 + if remainder != 0: + video = video[:-remainder] + condition_video = condition_video[:-remainder] + frame = video.shape[0] + # important, pad to 49 frames + num_repeats = self.max_num_frames - frame + last_frame = video[-1:] + repeated_frames = np.repeat(last_frame, num_repeats, axis=0) + video = np.concatenate([video, repeated_frames], axis=0) + + last_condition_frame = condition_video[-1:] + repeated_condition_frames = np.repeat(last_condition_frame, num_repeats, axis=0) + condition_video = np.concatenate([condition_video, repeated_condition_frames], axis=0) + + assert video.shape[0] == condition_video.shape[0], "video and condition_video shape is not equal" + + assert video.shape[0] == self.max_num_frames, "video shape is not equal to max_num_frames" + assert condition_video.shape[0] == self.max_num_frames, "condition_video shape is not equal to max_num_frames" + + + video = torch.from_numpy(video).permute(0, 3, 1, 2) + condition_video = torch.from_numpy(condition_video).permute(0, 3, 1, 2) + nearest_res = self._find_nearest_resolution(video.shape[2], video.shape[3]) + if self.video_reshape_mode == 'center': + video_resized = self._resize_for_rectangle_crop(video, nearest_res) + condition_video_resized = self._resize_for_rectangle_crop(condition_video, nearest_res) + elif self.video_reshape_mode == 'resize': + video_resized = torch.stack([resize(frame, nearest_res) for frame in video], dim=0) + condition_video_resized = torch.stack([resize(frame, nearest_res) for frame in condition_video], dim=0) + else: + raise NotImplementedError + video = video_resized + condition_video = condition_video_resized + + video = video.permute(0, 2, 3, 1).numpy() + condition_video = condition_video.permute(0, 2, 3, 1).numpy() + + for i in range(len(video)): + video[i] = cv2.cvtColor(video[i], cv2.COLOR_BGR2RGB) + video = (video.astype(np.float32) / 127.5) - 1.0 + + for i in range(len(condition_video)): + condition_video[i] = cv2.cvtColor(condition_video[i], cv2.COLOR_BGR2RGB) + condition_video = (condition_video.astype(np.float32) / 127.5) - 1.0 + + + if random.random() < self.mix_train_ratio: + video, condition_video = video[:1], condition_video[:1] + else: + if self.first_frame_gt: + video[0] = condition_video[0] + + pixel_values.append(torch.tensor(video).permute(0, 3, 1, 2)) + conditioning_pixel_values.append(torch.tensor(condition_video).permute(0, 3, 1, 2)) + input_ids.append(self.tokenize_captions(caption)[0]) + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + conditioning_pixel_values = torch.stack(conditioning_pixel_values) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack(input_ids) + + + else: + for example in examples: + caption=example["prompt"] + video=example["video"] # frame, height, width, c + condition_video=example["condition_video"] # frame, height, width, c + frame, height, width, c = video.shape + assert condition_video.shape[0] == frame, "condition_video shape is not equal to video shape" + + if frame > self.val_max_num_frames: + begin_idx = 0 + end_idx = begin_idx + self.val_max_num_frames + condition_video = condition_video[begin_idx:end_idx] + video = video[begin_idx:end_idx] + frame = end_idx - begin_idx + elif frame <= self.val_max_num_frames: + remainder = (3 + (frame % 4)) % 4 + if remainder != 0: + video = video[:-remainder] + condition_video = condition_video[:-remainder] + frame = video.shape[0] + + # important, pad to 49 frames + num_repeats = self.max_num_frames - frame + last_frame = video[-1:] + repeated_frames = np.repeat(last_frame, num_repeats, axis=0) + video = np.concatenate([video, repeated_frames], axis=0) + + last_condition_frame = condition_video[-1:] + repeated_condition_frames = np.repeat(last_condition_frame, num_repeats, axis=0) + condition_video = np.concatenate([condition_video, repeated_condition_frames], axis=0) + + assert video.shape[0] == condition_video.shape[0], "video and condition_video shape is not equal" + + assert video.shape[0] == self.max_num_frames, "video shape is not equal to max_num_frames" + assert condition_video.shape[0] == self.max_num_frames, "condition_video shape is not equal to max_num_frames" + + # resize video and condition_video to the nearest resolution + video = torch.from_numpy(video).permute(0, 3, 1, 2) + condition_video = torch.from_numpy(condition_video).permute(0, 3, 1, 2) + nearest_res = self._find_nearest_resolution(video.shape[2], video.shape[3]) + if self.video_reshape_mode == 'center': + video_resized = self._resize_for_rectangle_crop(video, nearest_res) + condition_video_resized = self._resize_for_rectangle_crop(condition_video, nearest_res) + elif self.video_reshape_mode == 'resize': + video_resized = torch.stack([resize(frame, nearest_res) for frame in video], dim=0) + condition_video_resized = torch.stack([resize(frame, nearest_res) for frame in condition_video], dim=0) + else: + raise NotImplementedError + video = video_resized + condition_video = condition_video_resized + + video = video.permute(0, 2, 3, 1).numpy() + condition_video = condition_video.permute(0, 2, 3, 1).numpy() + + assert len(condition_video) == len(video), f"condition_video shape is not equal to video shape for {example['prompt']}" + + if self.first_frame_gt: + condition_video[0] = video[0] + + video_ = [Image.fromarray(cv2.cvtColor(video[i], cv2.COLOR_BGR2RGB)) for i in range(len(video))] + condition_video_ = [Image.fromarray(cv2.cvtColor(condition_video[i], cv2.COLOR_BGR2RGB)) for i in range(len(condition_video))] + + pixel_values.append(video_) + conditioning_pixel_values.append(condition_video_) + input_ids.append(caption) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "input_ids": input_ids, + } + + +def log_validation( + pipe, + args, + accelerator, + pipeline_args, + epoch, + validating_step=0, + is_final_validation: bool = False, +): + logger.info(f"Running validation...") + scheduler_args = {} + + if "variance_type" in pipe.scheduler.config: + variance_type = pipe.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args) + pipe = pipe.to(accelerator.device) + # pipe.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipe.enable_xformers_memory_efficient_attention() + + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + pipeline_args["prompt"] = pipeline_args["prompt"][0] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, num_inference_steps=50, generator=generator, output_type="np").frames[0] + # Concatenate images horizontally + original_video = pipe.video_processor.preprocess_video(pipeline_args['video'][0], height=video.shape[1], width=video.shape[2]) + condition_video = pipe.video_processor.preprocess_video(pipeline_args['condition_video'][0], height=video.shape[1], width=video.shape[2]) + original_video = pipe.video_processor.postprocess_video(video=original_video, output_type="np")[0] + condition_video = pipe.video_processor.postprocess_video(video=condition_video, output_type="np")[0] + video_ = concatenate_images_horizontally( + original_video, + condition_video, + video + ) + videos.append(video_) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else f"validation_{epoch + 1:05d}_{validating_step}" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{pipeline_args['replace_gt']}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + clear_objs_and_retain_memory([pipe]) + + del pipe + torch.cuda.empty_cache() + gc.collect() + + return videos + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + if text_input_ids is None: + batch_size = len(prompt) + else: + batch_size = text_input_ids.shape[0] + + + if tokenizer is not None and text_input_ids is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + +def encode_prompt( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = _get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=text_input_ids, + ) + return prompt_embeds + + +def compute_prompt_embeddings( + tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False, token_ids=None +): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=token_ids, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=token_ids, + ) + return prompt_embeds + + +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + +def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): + # Use DeepSpeed optimzer + if use_deepspeed: + from accelerate.utils import DummyOptim + + return DummyOptim( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + + # Optimizer creation + supported_optimizers = ["adam", "adamw", "prodigy"] + if args.optimizer not in supported_optimizers: + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and (args.optimizer.lower() not in ["adam", "adamw"]): + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + if args.optimizer.lower() == "adamw": + optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "adam": + optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + return optimizer + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + # debug + from accelerate import DeepSpeedPlugin + deepspeed_plugin = DeepSpeedPlugin( + hf_ds_config='ds_config.json' + ) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + deepspeed_plugin=deepspeed_plugin # debug + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + + logger.info("Initializing CogVideoX branch weights from transformer") + branch = CogvideoXBranchModel.from_transformer( + transformer=transformer, + num_layers=args.branch_layer_num, + attention_head_dim=transformer.config.attention_head_dim, + num_attention_heads=transformer.config.num_attention_heads, + load_weights_from_transformer=True, + wo_text=args.wo_text, + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + # We only train the additional context encoder branch + text_encoder.requires_grad_(False) + transformer.requires_grad_(False) + vae.requires_grad_(False) + branch.requires_grad_(True) + branch.train() + + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.bfloat16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + branch.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + transformer.enable_xformers_memory_efficient_attention() + branch.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + branch.enable_gradient_checkpointing() + transformer.enable_gradient_checkpointing() + + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for model in models: + if isinstance(unwrap_model(model), type(unwrap_model(branch))): + model: CogvideoXBranchModel + model = unwrap_model(model) + model.save_pretrained( + os.path.join(output_dir, "branch"), safe_serialization=True, max_shard_size="5GB" + ) + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + if weights: + weights.pop() + + + def load_model_hook(models, input_dir): + while len(models) > 0: + model = models.pop() + + load_model = CogvideoXBranchModel.from_pretrained(input_dir, subfolder="branch") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([model]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + cast_training_params([branch], dtype=torch.float32) + + # Optimization parameters + branch_parameters = list(filter(lambda p: p.requires_grad, branch.parameters())) + branch_parameters_with_lr = {"params": branch_parameters, "lr": args.learning_rate} + params_to_optimize = [branch_parameters_with_lr] + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + logger.info(f"use_deepspeed_optimizer: {use_deepspeed_optimizer}, use_deepspeed_scheduler: {use_deepspeed_scheduler}") + + optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer) + + # Dataset and DataLoader + train_dataset = VideoInpaintingDataset( + meta_file_path=args.meta_file_path, + condition_data_root=args.condition_data_root, + video_data_root=args.video_data_root, + dataset_name=args.dataset_name, + dataset_config_name=args.dataset_config_name, + height=args.height, + width=args.width, + fps=args.fps, + max_num_frames=args.max_num_frames, + skip_frames_start=args.skip_frames_start, + skip_frames_end=args.skip_frames_end, + cache_dir=args.cache_dir, + id_token=args.id_token, + ) + validation_dataset = VideoInpaintingDataset( + meta_file_path=args.val_meta_file_path, + condition_data_root=args.val_condition_data_root, + video_data_root=args.video_data_root, + dataset_name=args.dataset_name, + dataset_config_name=args.dataset_config_name, + height=args.height, + width=args.width, + fps=args.fps, + max_num_frames=args.max_num_frames, + skip_frames_start=args.skip_frames_start, + skip_frames_end=args.skip_frames_end, + cache_dir=args.cache_dir, + id_token=args.id_token, + is_train=False, + ) + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=MyWebDataset( + resolution = args.resolution, + tokenizer=tokenizer, + random_mask=args.random_mask, + max_num_frames=args.max_num_frames, + max_sequence_length=args.max_text_seq_length, + proportion_empty_prompts=args.proportion_empty_prompts, + mask_background=args.mask_background, + mix_train_ratio=args.mix_train_ratio, + first_frame_gt=args.first_frame_gt, + is_train=True, + ), + pin_memory=args.pin_memory, + num_workers=args.dataloader_num_workers, + persistent_workers=False + ) + + validation_dataloader = DataLoader( + validation_dataset, + batch_size=1, + shuffle=True, + collate_fn=MyWebDataset( + resolution=args.resolution, + tokenizer=tokenizer, + random_mask=args.random_mask, + max_num_frames=args.max_num_frames, + max_sequence_length=args.max_text_seq_length, + proportion_empty_prompts=args.proportion_empty_prompts, + mask_background=args.mask_background, + first_frame_gt=args.first_frame_gt, + is_train=False, + ), + pin_memory=args.pin_memory, + num_workers=args.dataloader_num_workers, + persistent_workers=False + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + branch, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + branch, optimizer, train_dataloader, lr_scheduler + ) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = args.tracker_name or "VideoPainter" + accelerator.init_trackers( + project_name=tracker_name, + config=vars(args), + init_kwargs={"wandb": {"name": args.runs_name}} + ) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + logger.info("***** Running training *****") + logger.info(f" Num trainable parameters = {num_trainable_parameters}") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + logger.info( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + logger.info(f"Resuming from checkpoint {path}") + try: + accelerator.load_state(os.path.join(args.output_dir, path)) + except: + logger.info(f'failed loading from {os.path.join(args.output_dir, path)}') + accelerator.load_state(args.resume_from_checkpoint) + logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}") + + global_step = int(path.split("-")[-1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + disable=not accelerator.is_local_main_process, + ) + vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) + vae_scale_factor_temporal = vae.config.temporal_compression_ratio + + + alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32) + + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + for epoch in range(first_epoch, args.num_train_epochs): + branch.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [branch] + + with accelerator.accumulate(models_to_accumulate): + if global_step < 1: + pixel_values = (batch["pixel_values"][0, 0].permute(1, 2, 0).cpu().numpy() + 1) * 127.5 + conditioning_pixel_values = (batch["conditioning_pixel_values"][0, 0].permute(1, 2, 0).cpu().numpy() + 1) * 127.5 + pixel_values = pixel_values.clip(0, 255).astype(np.uint8) + conditioning_pixel_values = conditioning_pixel_values.clip(0, 255).astype(np.uint8) + pixel_values = cv2.cvtColor(pixel_values, cv2.COLOR_BGR2RGB) + conditioning_pixel_values = cv2.cvtColor(conditioning_pixel_values, cv2.COLOR_BGR2RGB) + + combined_image = np.hstack((pixel_values, conditioning_pixel_values)) + cv2.imwrite(f'{args.runs_name}_training_sample_1_{global_step}.png', combined_image) + + if batch["pixel_values"].shape[1] > 1: + pixel_values = (batch["pixel_values"][0, 1].permute(1, 2, 0).cpu().numpy() + 1) * 127.5 + conditioning_pixel_values = (batch["conditioning_pixel_values"][0, 1].permute(1, 2, 0).cpu().numpy() + 1) * 127.5 + + pixel_values = pixel_values.clip(0, 255).astype(np.uint8) + conditioning_pixel_values = conditioning_pixel_values.clip(0, 255).astype(np.uint8) + + pixel_values = cv2.cvtColor(pixel_values, cv2.COLOR_BGR2RGB) + conditioning_pixel_values = cv2.cvtColor(conditioning_pixel_values, cv2.COLOR_BGR2RGB) + + combined_image = np.hstack((pixel_values, conditioning_pixel_values)) + cv2.imwrite(f'{args.runs_name}_training_sample_2_{global_step}.png', combined_image) + + images = batch["pixel_values"][:, :1, :, :, :].permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + image_noise_sigma = torch.normal( + mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype + ) + image_noise_sigma = torch.exp(image_noise_sigma) + noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None] + image_latents = vae.encode(noisy_images.to(dtype=weight_dtype)).latent_dist.sample() + image_latents = image_latents.permute(0, 2, 1, 3, 4) * vae.config.scaling_factor # [B, F, C, H, W] + image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + model_input = vae.encode(batch["pixel_values"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)).latent_dist.sample() + model_input = model_input.permute(0, 2, 1, 3, 4) * vae.config.scaling_factor # [B, F, C, H, W] + model_input = model_input.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + + conditioning_latents=vae.encode(batch["conditioning_pixel_values"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)).latent_dist.sample() + conditioning_latents = conditioning_latents.permute(0, 2, 1, 3, 4) * vae.config.scaling_factor # [B, F, C, H, W] + conditioning_latents = conditioning_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + padding_shape = (conditioning_latents.shape[0], conditioning_latents.shape[1] - 1, *conditioning_latents.shape[2:]) + latent_padding = image_latents.new_zeros(padding_shape) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if random.random() < args.noised_image_dropout: + image_latents = torch.zeros_like(image_latents) + + torch.cuda.empty_cache() + + # conditioning_latents=torch.concat([conditioning_latents, masks], -3).to(dtype=weight_dtype) # we do not use masks + prompts = batch["prompts"] if batch['input_ids'] is None else 'xx' + + # encode prompts + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + args.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + token_ids=batch["input_ids"], + ) + # Sample noise that will be added to the latents + noise = torch.randn_like(model_input).to(dtype=weight_dtype) + batch_size, num_frames, num_channels, height, width = model_input.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device + ) + timesteps = timesteps.long() + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=args.height, + width=args.width, + num_frames=num_frames, + vae_scale_factor_spatial=vae_scale_factor_spatial, + patch_size=model_config.patch_size, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_video_latents = scheduler.add_noise(model_input, noise, timesteps) + noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2) + + + # logger.info(f"noisy_model_input.shape: {noisy_model_input.shape}; prompt_embeds.shape: {prompt_embeds.shape}") + branch_block_samples = branch( + hidden_states=noisy_video_latents, + encoder_hidden_states=prompt_embeds, + branch_cond=conditioning_latents, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + mask_add=args.mask_add, + wo_text=args.wo_text, + return_dict=False, + )[0] + branch_block_samples = [block_sample.to(dtype=weight_dtype) for block_sample in branch_block_samples] + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + return_dict=False, + branch_block_samples=branch_block_samples, + branch_block_masks=masks if args.mask_add else None, + add_first=args.add_first, + )[0] + + model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps) + + weights = 1 / (1 - alphas_cumprod[timesteps]) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = model_input + + loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1) + loss = loss.mean() + # inpainting_loss = torch.mean((weights * (model_pred*masks - target*masks) ** 2).reshape(batch_size, -1), dim=1) + # inpainting_loss = inpainting_loss.mean() + # loss = loss + args.inpainting_loss_weight * inpainting_loss + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = branch.parameters() + gradient_norm_before_clip = get_gradient_norm(params_to_clip) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + gradient_norm_after_clip = get_gradient_norm(params_to_clip) + + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + torch.cuda.empty_cache() + + if global_step % args.checkpointing_steps == 0: + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + + + # logs = {"loss": loss.detach().item(), "inpainting_loss": inpainting_loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + if accelerator.distributed_type != DistributedType.DEEPSPEED: + logs.update( + { + "gradient_norm_before_clip": gradient_norm_before_clip, + "gradient_norm_after_clip": gradient_norm_after_clip, + } + ) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if (global_step % args.validating_steps == 0): # or (global_step == 1): + pipe = CogVideoXI2VDualInpaintPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(transformer), + text_encoder=unwrap_model(text_encoder), + vae=unwrap_model(vae), + branch=unwrap_model(branch), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + vali_num = 0 + for step, batch in enumerate(validation_dataloader): + pipeline_args = { + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + "image": batch["pixel_values"][0][0], + "video": batch["pixel_values"], + "prompt": batch["input_ids"], + "condition_video": batch["conditioning_pixel_values"], + "num_frames": np.array(batch['pixel_values'][0]).shape[0], + "strength": 1.0, + "mask_background": args.mask_background, + "add_first": args.add_first, + "wo_text": args.wo_text, + "mask_add": args.mask_add, + "replace_gt": (vali_num+1) % 2 == 1, + } + + validation_outputs = log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + validating_step=global_step, + ) + del batch + torch.cuda.empty_cache() + gc.collect() + vali_num += 1 + if vali_num >= 2: + break + if ((epoch + 1) % args.validation_epochs == 0): # or (epoch == 0) + try: + pipe = CogVideoXI2VDualInpaintPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(transformer), + text_encoder=unwrap_model(text_encoder), + vae=unwrap_model(vae), + branch=unwrap_model(branch), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + for step, batch in enumerate(validation_dataloader): + if step < 1: + try: + pipeline_args = { + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + "image": batch["pixel_values"][0][0], + "video": batch["pixel_values"], + "prompt": batch["input_ids"], + "condition_video": batch["conditioning_pixel_values"], + "num_frames": np.array(batch['pixel_values'][0]).shape[0], + "strength": 1.0, + "mask_background": args.mask_background, + "add_first": args.add_first, + "mask_add": args.mask_add, + "replace_gt": (vali_num+1) % 2 == 1, + } + validation_outputs = log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + validating_step=global_step, + ) + except Exception as e: + logger.error(f"Error during validation step: {e}") + continue + + del batch + torch.cuda.empty_cache() + gc.collect() + + if step >= 1: + break + + del pipe + torch.cuda.empty_cache() + gc.collect() + + except Exception as e: + logger.error(f"Error during validation: {e}") + + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args)